add device generalization support for distributed tests (#156796)

MOTIVATION
Generalize distributed test cases so they also run on non-CUDA accelerator devices, instead of assuming CUDA.

CHANGES

- test/distributed/checkpoint/test_fsspec.py
- test/distributed/checkpoint/test_state_dict.py
- test/distributed/test_multi_threaded_pg.py

Replaced hard-coded "cuda" device names with the device type reported by torch.accelerator.current_accelerator() (falling back to "cpu"), and selected the process-group backend via torch.distributed.get_default_backend_for_device().
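
For reference, the pattern used across these tests looks like the following minimal sketch (it mirrors the added lines in the diffs below; the comments on the backend mapping are illustrative):

    import torch
    import torch.distributed as dist

    # Resolve the active accelerator type ("cuda", "hpu", "xpu", ...); fall back to CPU.
    device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
    # Pick the matching default process-group backend (e.g. "cuda" -> "nccl", "cpu" -> "gloo").
    BACKEND = dist.get_default_backend_for_device(device_type)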

- torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py

Added support for the hccl backend in ShardedTensorTestBase.init_pg.
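
As a hedged sketch (the class and test names below are hypothetical), the relaxed check lets a test based on ShardedTensorTestBase bring up an hccl process group on builds where that backend is registered:

    class MyShardedTensorTest(ShardedTensorTestBase):
        def test_something(self):
            # Hypothetical HPU run; before this change init_pg accepted
            # only "nccl", "gloo", and "mpi".
            self.init_pg(backend="hccl")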

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156796
Approved by: https://github.com/guangyey, https://github.com/ezyang
Author: Hari Krishna Sai Kodali
Date: 2025-07-16 09:37:00 +00:00
Committed by: PyTorch MergeBot
Parent: ea74fdd24a
Commit: 9d184bda2f

4 changed files with 68 additions and 53 deletions

test/distributed/checkpoint/test_fsspec.py

@@ -18,7 +18,10 @@ from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
 from torch.distributed.checkpoint.utils import CheckpointException
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import (
+    requires_accelerator_dist_backend,
+    skip_if_lt_x_gpu,
+)
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.testing._internal.distributed._shard.sharded_tensor import (
     ShardedTensorTestBase,
@@ -26,6 +29,10 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
 )


+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+BACKEND = torch.distributed.get_default_backend_for_device(device_type)
+
+
 def with_temp_dir(
     func: Optional[Callable] = None,
 ) -> Optional[Callable]:
@@ -75,14 +82,14 @@ class TestFSSpec(ShardedTensorTestBase):
     def world_size(self) -> int:
         return 2

-    @with_comms(init_rpc=False)
+    @with_comms(backend=BACKEND, init_rpc=False)
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(2)
-    @requires_nccl()
     @with_temp_dir
     def test_fsspec(self):
         CHECKPOINT_DIR = self.temp_dir

-        model = FSDP(MyTestModule().cuda())
+        model = FSDP(MyTestModule().to(device_type))
         optim = torch.optim.Adam(model.parameters(), lr=0.1)
         model(torch.rand(8, 8, device=dist.get_rank())).sum().backward()
         optim.step()
@@ -99,7 +106,7 @@ class TestFSSpec(ShardedTensorTestBase):
             planner=dcp.DefaultSavePlanner(),
         )

-        model_2 = FSDP(MyTestModule().cuda())
+        model_2 = FSDP(MyTestModule().to(device_type))
         optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.1)

         with FSDP.summon_full_params(model):
@@ -149,9 +156,9 @@ class TestFSSpec(ShardedTensorTestBase):
             opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"]
         )

-    @with_comms(init_rpc=False)
+    @with_comms(backend=BACKEND, init_rpc=False)
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(2)
-    @requires_nccl()
     @with_temp_dir
     def test_overwrite(self):
         t1, t2 = torch.randn(10), torch.randn(10)

test/distributed/checkpoint/test_state_dict.py

@@ -62,6 +62,9 @@ from torch.testing._internal.distributed.common_state_dict import (
 from torch.utils._pytree import tree_all, tree_all_only

+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+
+
 if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)
@@ -79,7 +82,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.accelerator.device_count())

     def _test_save_load(
         self,
@@ -101,7 +104,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             for d_optim in _dist_optim:
                 d_optim.zero_grad()

-            batch = torch.rand(8, 100, device="cuda")
+            batch = torch.rand(8, 100, device=device_type)
             model(batch).sum().backward()
             dist_model(batch).sum().backward()
@@ -188,9 +191,9 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         def init_model_optim():
             if use_dtensor:
-                device_mesh = init_device_mesh("cuda", (self.world_size,))
+                device_mesh = init_device_mesh(device_type, (self.world_size,))

-            orig_model = CompositeParamModel(device=torch.device("cuda"))
+            orig_model = CompositeParamModel(device=torch.device(device_type))
             orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True)
             copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True)
             if wrapping:
@@ -198,7 +201,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             else:
                 strategy = {UnitModule}
             if use_dtensor:
-                device_mesh = init_device_mesh("cuda", (self.world_size,))
+                device_mesh = init_device_mesh(device_type, (self.world_size,))
             dist_model = FSDP(
                 copy.deepcopy(orig_model),
                 auto_wrap_policy=ModuleWrapPolicy(strategy),
@@ -258,7 +261,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         foreach: bool = True,
     ):
         def init_model_optim():
-            orig_model = CompositeParamModel(device=torch.device("cuda"))
+            orig_model = CompositeParamModel(device=torch.device(device_type))
             orig_optim = optimizer_class(
                 orig_model.parameters(), lr=1e-4, foreach=foreach
             )
@@ -295,7 +298,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     def _test_ddp(self, use_composable: bool, optimizer_class: type[Optimizer]) -> None:
         def init_model_optim():
-            orig_model = CompositeParamModel(device=torch.device("cuda"))
+            orig_model = CompositeParamModel(device=torch.device(device_type))
             orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
             copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
             if use_composable:
@@ -329,7 +332,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         test_frozen: bool = False,
     ) -> None:
         def init_model_optim():
-            orig_model = CompositeParamModel(device=torch.device("cuda"))
+            orig_model = CompositeParamModel(device=torch.device(device_type))
             if test_frozen:
                 for param in chain(
                     orig_model.u1.parameters(), orig_model.u2.parameters()
@@ -370,7 +373,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     def _test_single_gpu(self, optimizer_class: type[Optimizer]) -> None:
         def init_model_optim():
-            orig_model = CompositeParamModel(device=torch.device("cuda"))
+            orig_model = CompositeParamModel(device=torch.device(device_type))
             orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
             copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
             model_copy = copy.deepcopy(orig_model)
@@ -385,7 +388,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         self._test_single_gpu(torch.optim.AdamW)

     def _test_strict(self, parallelism: str) -> None:
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
         if parallelism == "DDP":
             model = DDP(model)
         else:
@@ -422,8 +425,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     def _test_cpu_offload_full_state_dict(
         self, optimizer_class: type[Optimizer]
     ) -> None:
-        orig_model = CompositeParamModel(device=torch.device("cuda"))
-        device_mesh = init_device_mesh("cuda", (self.world_size,))
+        orig_model = CompositeParamModel(device=torch.device(device_type))
+        device_mesh = init_device_mesh(device_type, (self.world_size,))
         dist_model = FSDP(
             copy.deepcopy(orig_model),
             auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
@@ -499,7 +502,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     @skip_if_lt_x_gpu(1)
     def test_activation_ckpt_fqns_ddp(self) -> None:
         """Tests that activation checkpointing prefixes are removed from module names"""
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
         original_keys = get_model_state_dict(model).keys()

         apply_activation_checkpointing(model)
@@ -518,7 +521,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None:
         """Tests that activation checkpointing prefixes are removed from module names"""
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
         original_keys = get_model_state_dict(model).keys()

         apply_activation_checkpointing(model)
@@ -529,7 +532,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     @skip_if_lt_x_gpu(1)
     def test_extra_state(self) -> None:
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))

         def get_extra_state(self):
             return "MyState"
@@ -547,21 +550,21 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     @skip_if_lt_x_gpu(1)
     def test_non_persistent_buffers(self) -> None:
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
         model.register_buffer(
-            "dont_save_me", torch.rand(100, device="cuda"), persistent=False
+            "dont_save_me", torch.rand(100, device=device_type), persistent=False
         )
         target_model = copy.deepcopy(model)
         set_model_state_dict(target_model, get_model_state_dict(target_model))
         self.assertEqual(model.state_dict(), get_model_state_dict(target_model))

     def _test_broadcast_from_rank0(self, wrapper) -> None:
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
         optim = torch.optim.Adam(model.parameters())
         fsdp_model = wrapper(copy.deepcopy(model))
         fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
-        batch = torch.rand(8, 100, device="cuda")
+        batch = torch.rand(8, 100, device=device_type)

         model(batch).sum().backward()
         optim.step()
         states, optim_states = get_state_dict(model, optim)
@@ -631,8 +634,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     @with_comms
     @skip_if_lt_x_gpu(4)
     def test_broadcast_from_rank0(self) -> None:
-        device_mesh = init_device_mesh("cuda", (self.world_size,))
-        hsdp_device_mesh = init_device_mesh("cuda", (2, self.world_size // 2))
+        device_mesh = init_device_mesh(device_type, (self.world_size,))
+        hsdp_device_mesh = init_device_mesh(device_type, (2, self.world_size // 2))
         self.run_subtests(
             {
                 "wrapper": [
@@ -654,8 +657,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         # This test verifies that FSDP root is not initialized but we should
         # still be able to get the state_dict without errors because
         # fsdp_model.state_dict() will trigger the FSDP initialization.
-        device_mesh = init_device_mesh("cuda", (self.world_size,))
-        model = CompositeParamModel(device=torch.device("cuda"))
+        device_mesh = init_device_mesh(device_type, (self.world_size,))
+        model = CompositeParamModel(device=torch.device(device_type))
         fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh)
         fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
         get_model_state_dict(fsdp_model)
@@ -668,10 +671,9 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         # "initial_lr" is added to optim_state_dict, but not to the new optim
         # We test whether "initial_lr" appear in optim after
         # set_optimizer_state_dict.
-        device = "cuda"
         torch.manual_seed(0)
         model = nn.Sequential(
-            *[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)]
+            *[nn.Linear(4, 4, device=device_type, bias=False) for _ in range(2)]
         )
         for layer in model:
             fully_shard(layer)
@@ -705,11 +707,11 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     @with_comms
     @skip_if_lt_x_gpu(2)
     def test_flattened_osd(self) -> None:
-        device_mesh = init_device_mesh("cuda", (self.world_size,))
-        model = CompositeParamModel(device=torch.device("cuda"))
+        device_mesh = init_device_mesh(device_type, (self.world_size,))
+        model = CompositeParamModel(device=torch.device(device_type))
         fsdp_model = fully_shard(copy.deepcopy(model), mesh=device_mesh)
         fsdp_optim = torch.optim.AdamW(fsdp_model.parameters())
-        batch = torch.rand(8, 100, device="cuda")
+        batch = torch.rand(8, 100, device=device_type)
         fsdp_model(batch).sum().backward()
         fsdp_optim.step()
         fsdp_optim.zero_grad()
@@ -730,7 +732,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict())

     def _test_deprecate_partial(self) -> None:
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
         model_state_dict1 = get_model_state_dict(model)
         model_state_dict1 = copy.deepcopy(model_state_dict1)
@@ -783,8 +785,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         self.assertEqual(model.l.bias, model_state_dict1["l.bias"])

     def _test_deprecate_fsdp_api(self) -> None:
-        device_mesh = init_device_mesh("cuda", (self.world_size,))
-        model = CompositeParamModel(device=torch.device("cuda"))
+        device_mesh = init_device_mesh(device_type, (self.world_size,))
+        model = CompositeParamModel(device=torch.device(device_type))
         fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh)
         with self.assertWarnsRegex(
             FutureWarning,
@@ -823,8 +825,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
                 return output

         def init_model_optim():
-            device_mesh = init_device_mesh("cuda", (self.world_size,))
-            orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda"))
+            device_mesh = init_device_mesh(device_type, (self.world_size,))
+            orig_model = TiedEmbeddingModel(10000, 300).to(torch.device(device_type))
             orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
             copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
             dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh)
@@ -905,8 +907,12 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             self.assertEqual(cpu_model_value, meta_model_value)

         # Memory allocated and reserved are lower due to the change at _distribute_tensors
         # from view to clone. This test would fail if with view due to higher memory cost.
-        memory_allocated = torch.cuda.memory_allocated(0) / 1024 / 1024
-        memory_reserved = torch.cuda.memory_reserved(0) / 1024 / 1024
+        memory_allocated = (
+            torch.get_device_module(device_type).memory_allocated(0) / 1024 / 1024
+        )
+        memory_reserved = (
+            torch.get_device_module(device_type).memory_reserved(0) / 1024 / 1024
+        )
         self.assertTrue(memory_allocated <= 384)
         self.assertTrue(memory_reserved <= 768)
@@ -942,11 +948,11 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             meta_submodel = nn.Linear(4, 4, bias=False)
         with torch.device("cpu"):
             cpu_submodel = nn.Linear(4, 4, bias=False)
-        with torch.device("cuda"):
-            cuda_submodel = nn.Linear(4, 4, bias=False)
+        with torch.device(device_type):
+            acc_submodel = nn.Linear(4, 4, bias=False)

-        two_device_model_with_meta = nn.Sequential(meta_submodel, cuda_submodel)
-        two_device_model_without_meta = nn.Sequential(cpu_submodel, cuda_submodel)
+        two_device_model_with_meta = nn.Sequential(meta_submodel, acc_submodel)
+        two_device_model_without_meta = nn.Sequential(cpu_submodel, acc_submodel)

         with torch.device("cpu"):
             model_to_set = nn.Sequential(
@@ -974,7 +980,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     def test_state_dict_with_hook_on_keys(self) -> None:
         with torch.device("meta"):
             metamodel = FusionEmbedding(4, 4, 4)
-        with torch.device("cuda"):
+        with torch.device(device_type):
             gpumodel = FusionEmbeddingWithHook(4, 4, 4)
         gpumodel_state_dict = get_model_state_dict(gpumodel)
         with self.assertRaisesRegex(RuntimeError, "Missing key"):
@@ -995,8 +1001,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             def forward(self, x):
                 return self.fc1(self.fc(x))

-        device_mesh = init_device_mesh("cuda", (self.world_size,))
-        model = TestModel().cuda()
+        device_mesh = init_device_mesh(device_type, (self.world_size,))
+        model = TestModel().to(device_type)
         parallelize_module(
             model,
             device_mesh,
@@ -1014,7 +1020,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             optim = torch.optim.AdamW(**optim_kwargs)
             optim.zero_grad()

-            model(torch.randn(64, 64).cuda()).sum().backward()
+            model(torch.randn(64, 64, device=device_type)).sum().backward()
             optim.step()
             optim.zero_grad()
@@ -1067,7 +1073,7 @@ class TestNoComm(MultiProcessTestCase):
     @skip_if_lt_x_gpu(1)
     def test_no_dist(self) -> None:
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
         optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
         self.assertFalse(dist.is_initialized())

test/distributed/test_multi_threaded_pg.py

@@ -25,6 +25,8 @@ from torch.testing._internal.common_distributed import (
 from torch.testing._internal.common_utils import IS_SANDCASTLE, run_tests, TestCase

+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+
 DEFAULT_WORLD_SIZE = 4
@@ -330,7 +332,7 @@ class TestCollectivesWithBaseClass(MultiThreadedTestCase):
                 return grad_output * result

         x = torch.tensor(
-            [dist.get_rank()], dtype=torch.float, device="cuda", requires_grad=True
+            [dist.get_rank()], dtype=torch.float, device=device_type, requires_grad=True
         )
         x = MyFunc.apply(x)
         x.sum().backward()

torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py

@@ -22,7 +22,7 @@ class ShardedTensorTestBase(MultiProcessTestCase):
         return TEST_GPU_NUM

     def init_pg(self, backend="nccl"):
-        if backend not in ["nccl", "gloo", "mpi"]:
+        if backend not in ["nccl", "gloo", "mpi", "hccl"]:
             raise RuntimeError(f"Backend {backend} not supported!")
         dist.init_process_group(