Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
add device generalization support for distributed tests (#156796)
MOTIVATION
To generalize distributed test cases for non-CUDA devices.

CHANGES
- test/distributed/checkpoint/test_fsspec.py
- test/distributed/checkpoint/test_state_dict.py
- test/distributed/test_multi_threaded_pg.py
  Replaced hard-coded device names with torch.accelerator.current_accelerator().
- torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py
  Added support for the hccl backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156796
Approved by: https://github.com/guangyey, https://github.com/ezyang
commit 9d184bda2f
parent ea74fdd24a
committed by PyTorch MergeBot
--- a/test/distributed/checkpoint/test_fsspec.py
+++ b/test/distributed/checkpoint/test_fsspec.py
@@ -18,7 +18,10 @@ from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
 from torch.distributed.checkpoint.utils import CheckpointException
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import (
+    requires_accelerator_dist_backend,
+    skip_if_lt_x_gpu,
+)
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.testing._internal.distributed._shard.sharded_tensor import (
     ShardedTensorTestBase,
@@ -26,6 +29,10 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
 )
 
 
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+BACKEND = torch.distributed.get_default_backend_for_device(device_type)
+
+
 def with_temp_dir(
     func: Optional[Callable] = None,
 ) -> Optional[Callable]:
@@ -75,14 +82,14 @@ class TestFSSpec(ShardedTensorTestBase):
     def world_size(self) -> int:
         return 2
 
-    @with_comms(init_rpc=False)
+    @with_comms(backend=BACKEND, init_rpc=False)
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(2)
-    @requires_nccl()
     @with_temp_dir
     def test_fsspec(self):
         CHECKPOINT_DIR = self.temp_dir
 
-        model = FSDP(MyTestModule().cuda())
+        model = FSDP(MyTestModule().to(device_type))
         optim = torch.optim.Adam(model.parameters(), lr=0.1)
         model(torch.rand(8, 8, device=dist.get_rank())).sum().backward()
         optim.step()
@@ -99,7 +106,7 @@ class TestFSSpec(ShardedTensorTestBase):
             planner=dcp.DefaultSavePlanner(),
         )
 
-        model_2 = FSDP(MyTestModule().cuda())
+        model_2 = FSDP(MyTestModule().to(device_type))
         optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.1)
 
         with FSDP.summon_full_params(model):
@@ -149,9 +156,9 @@ class TestFSSpec(ShardedTensorTestBase):
             opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"]
         )
 
-    @with_comms(init_rpc=False)
+    @with_comms(backend=BACKEND, init_rpc=False)
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(2)
-    @requires_nccl()
     @with_temp_dir
     def test_overwrite(self):
         t1, t2 = torch.randn(10), torch.randn(10)
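Editor's note: the decorator change above swaps the NCCL-only gate for a backend-agnostic one. A sketch of the resulting stack, using names from this diff (BACKEND is the module-level constant added above):

    @with_comms(backend=BACKEND, init_rpc=False)  # init the process group with the derived backend
    @requires_accelerator_dist_backend()          # skip unless an accelerator dist backend is available
    @skip_if_lt_x_gpu(2)                          # still requires two devices
    @with_temp_dir
    def test_fsspec(self):
        ...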
--- a/test/distributed/checkpoint/test_state_dict.py
+++ b/test/distributed/checkpoint/test_state_dict.py
@@ -62,6 +62,9 @@ from torch.testing._internal.distributed.common_state_dict import (
 from torch.utils._pytree import tree_all, tree_all_only
 
 
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+
+
 if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)
@@ -79,7 +82,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
 
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.accelerator.device_count())
 
     def _test_save_load(
         self,
@@ -101,7 +104,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         for d_optim in _dist_optim:
             d_optim.zero_grad()
 
-        batch = torch.rand(8, 100, device="cuda")
+        batch = torch.rand(8, 100, device=device_type)
         model(batch).sum().backward()
         dist_model(batch).sum().backward()
 
@@ -188,9 +191,9 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
 
         def init_model_optim():
             if use_dtensor:
-                device_mesh = init_device_mesh("cuda", (self.world_size,))
+                device_mesh = init_device_mesh(device_type, (self.world_size,))
 
-            orig_model = CompositeParamModel(device=torch.device("cuda"))
+            orig_model = CompositeParamModel(device=torch.device(device_type))
             orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True)
             copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True)
             if wrapping:
@@ -198,7 +201,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             else:
                 strategy = {UnitModule}
             if use_dtensor:
-                device_mesh = init_device_mesh("cuda", (self.world_size,))
+                device_mesh = init_device_mesh(device_type, (self.world_size,))
             dist_model = FSDP(
                 copy.deepcopy(orig_model),
                 auto_wrap_policy=ModuleWrapPolicy(strategy),
@@ -258,7 +261,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         foreach: bool = True,
     ):
         def init_model_optim():
-            orig_model = CompositeParamModel(device=torch.device("cuda"))
+            orig_model = CompositeParamModel(device=torch.device(device_type))
             orig_optim = optimizer_class(
                 orig_model.parameters(), lr=1e-4, foreach=foreach
             )
@@ -295,7 +298,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
 
     def _test_ddp(self, use_composable: bool, optimizer_class: type[Optimizer]) -> None:
        def init_model_optim():
-            orig_model = CompositeParamModel(device=torch.device("cuda"))
+            orig_model = CompositeParamModel(device=torch.device(device_type))
             orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
             copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
             if use_composable:
@@ -329,7 +332,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         test_frozen: bool = False,
     ) -> None:
         def init_model_optim():
-            orig_model = CompositeParamModel(device=torch.device("cuda"))
+            orig_model = CompositeParamModel(device=torch.device(device_type))
             if test_frozen:
                 for param in chain(
                     orig_model.u1.parameters(), orig_model.u2.parameters()
@@ -370,7 +373,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
 
     def _test_single_gpu(self, optimizer_class: type[Optimizer]) -> None:
         def init_model_optim():
-            orig_model = CompositeParamModel(device=torch.device("cuda"))
+            orig_model = CompositeParamModel(device=torch.device(device_type))
             orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
             copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
             model_copy = copy.deepcopy(orig_model)
@@ -385,7 +388,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         self._test_single_gpu(torch.optim.AdamW)
 
     def _test_strict(self, parallelism: str) -> None:
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
         if parallelism == "DDP":
             model = DDP(model)
         else:
@@ -422,8 +425,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     def _test_cpu_offload_full_state_dict(
         self, optimizer_class: type[Optimizer]
     ) -> None:
-        orig_model = CompositeParamModel(device=torch.device("cuda"))
-        device_mesh = init_device_mesh("cuda", (self.world_size,))
+        orig_model = CompositeParamModel(device=torch.device(device_type))
+        device_mesh = init_device_mesh(device_type, (self.world_size,))
         dist_model = FSDP(
             copy.deepcopy(orig_model),
             auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
@@ -499,7 +502,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     @skip_if_lt_x_gpu(1)
     def test_activation_ckpt_fqns_ddp(self) -> None:
         """Tests that activation checkpointing prefixes are removed from module names"""
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
         original_keys = get_model_state_dict(model).keys()
 
         apply_activation_checkpointing(model)
@@ -518,7 +521,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
 
     def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None:
         """Tests that activation checkpointing prefixes are removed from module names"""
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
         original_keys = get_model_state_dict(model).keys()
 
         apply_activation_checkpointing(model)
@@ -529,7 +532,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
 
     @skip_if_lt_x_gpu(1)
     def test_extra_state(self) -> None:
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
 
         def get_extra_state(self):
             return "MyState"
@@ -547,21 +550,21 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
 
     @skip_if_lt_x_gpu(1)
     def test_non_persistent_buffers(self) -> None:
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
         model.register_buffer(
-            "dont_save_me", torch.rand(100, device="cuda"), persistent=False
+            "dont_save_me", torch.rand(100, device=device_type), persistent=False
         )
         target_model = copy.deepcopy(model)
         set_model_state_dict(target_model, get_model_state_dict(target_model))
         self.assertEqual(model.state_dict(), get_model_state_dict(target_model))
 
     def _test_broadcast_from_rank0(self, wrapper) -> None:
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
         optim = torch.optim.Adam(model.parameters())
         fsdp_model = wrapper(copy.deepcopy(model))
         fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
 
-        batch = torch.rand(8, 100, device="cuda")
+        batch = torch.rand(8, 100, device=device_type)
         model(batch).sum().backward()
         optim.step()
         states, optim_states = get_state_dict(model, optim)
@@ -631,8 +634,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     @with_comms
     @skip_if_lt_x_gpu(4)
     def test_broadcast_from_rank0(self) -> None:
-        device_mesh = init_device_mesh("cuda", (self.world_size,))
-        hsdp_device_mesh = init_device_mesh("cuda", (2, self.world_size // 2))
+        device_mesh = init_device_mesh(device_type, (self.world_size,))
+        hsdp_device_mesh = init_device_mesh(device_type, (2, self.world_size // 2))
         self.run_subtests(
             {
                 "wrapper": [
@@ -654,8 +657,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         # This test verifies that FSDP root is not initialized but we should
         # still be able to get the state_dict without errors because
         # fsdp_model.state_dict() will trigger the FSDP initialization.
-        device_mesh = init_device_mesh("cuda", (self.world_size,))
-        model = CompositeParamModel(device=torch.device("cuda"))
+        device_mesh = init_device_mesh(device_type, (self.world_size,))
+        model = CompositeParamModel(device=torch.device(device_type))
         fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh)
         fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
         get_model_state_dict(fsdp_model)
@@ -668,10 +671,9 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         # "initial_lr" is added to optim_state_dict, but not to the new optim
         # We test whether "initial_lr" appear in optim after
         # set_optimizer_state_dict.
-        device = "cuda"
         torch.manual_seed(0)
         model = nn.Sequential(
-            *[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)]
+            *[nn.Linear(4, 4, device=device_type, bias=False) for _ in range(2)]
         )
         for layer in model:
             fully_shard(layer)
@@ -705,11 +707,11 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     @with_comms
     @skip_if_lt_x_gpu(2)
     def test_flattened_osd(self) -> None:
-        device_mesh = init_device_mesh("cuda", (self.world_size,))
-        model = CompositeParamModel(device=torch.device("cuda"))
+        device_mesh = init_device_mesh(device_type, (self.world_size,))
+        model = CompositeParamModel(device=torch.device(device_type))
         fsdp_model = fully_shard(copy.deepcopy(model), mesh=device_mesh)
         fsdp_optim = torch.optim.AdamW(fsdp_model.parameters())
-        batch = torch.rand(8, 100, device="cuda")
+        batch = torch.rand(8, 100, device=device_type)
         fsdp_model(batch).sum().backward()
         fsdp_optim.step()
         fsdp_optim.zero_grad()
@@ -730,7 +732,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict())
 
     def _test_deprecate_partial(self) -> None:
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
 
         model_state_dict1 = get_model_state_dict(model)
         model_state_dict1 = copy.deepcopy(model_state_dict1)
@@ -783,8 +785,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         self.assertEqual(model.l.bias, model_state_dict1["l.bias"])
 
     def _test_deprecate_fsdp_api(self) -> None:
-        device_mesh = init_device_mesh("cuda", (self.world_size,))
-        model = CompositeParamModel(device=torch.device("cuda"))
+        device_mesh = init_device_mesh(device_type, (self.world_size,))
+        model = CompositeParamModel(device=torch.device(device_type))
         fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh)
         with self.assertWarnsRegex(
             FutureWarning,
@@ -823,8 +825,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
                 return output
 
         def init_model_optim():
-            device_mesh = init_device_mesh("cuda", (self.world_size,))
-            orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda"))
+            device_mesh = init_device_mesh(device_type, (self.world_size,))
+            orig_model = TiedEmbeddingModel(10000, 300).to(torch.device(device_type))
             orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
             copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
             dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh)
@@ -905,8 +907,12 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             self.assertEqual(cpu_model_value, meta_model_value)
         # Memory allocated and reserved are lower due to the change at _distribute_tensors
         # from view to clone. This test would fail if with view due to higher memory cost.
-        memory_allocated = torch.cuda.memory_allocated(0) / 1024 / 1024
-        memory_reserved = torch.cuda.memory_reserved(0) / 1024 / 1024
+        memory_allocated = (
+            torch.get_device_module(device_type).memory_allocated(0) / 1024 / 1024
+        )
+        memory_reserved = (
+            torch.get_device_module(device_type).memory_reserved(0) / 1024 / 1024
+        )
         self.assertTrue(memory_allocated <= 384)
         self.assertTrue(memory_reserved <= 768)
 
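Editor's note: torch.get_device_module(device) returns the per-device torch module (torch.cuda, torch.xpu, ...), which is what lets the memory assertions above avoid naming torch.cuda directly. A minimal sketch, assuming the current accelerator exposes the CUDA-style memory API:

    import torch

    device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
    mod = torch.get_device_module(device_type)  # e.g. torch.cuda when device_type == "cuda"
    mb_allocated = mod.memory_allocated(0) / 1024 / 1024  # MiB currently allocated on device 0
    mb_reserved = mod.memory_reserved(0) / 1024 / 1024    # MiB reserved by the caching allocator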
@@ -942,11 +948,11 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             meta_submodel = nn.Linear(4, 4, bias=False)
         with torch.device("cpu"):
             cpu_submodel = nn.Linear(4, 4, bias=False)
-        with torch.device("cuda"):
-            cuda_submodel = nn.Linear(4, 4, bias=False)
+        with torch.device(device_type):
+            acc_submodel = nn.Linear(4, 4, bias=False)
 
-        two_device_model_with_meta = nn.Sequential(meta_submodel, cuda_submodel)
-        two_device_model_without_meta = nn.Sequential(cpu_submodel, cuda_submodel)
+        two_device_model_with_meta = nn.Sequential(meta_submodel, acc_submodel)
+        two_device_model_without_meta = nn.Sequential(cpu_submodel, acc_submodel)
 
         with torch.device("cpu"):
             model_to_set = nn.Sequential(
@@ -974,7 +980,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     def test_state_dict_with_hook_on_keys(self) -> None:
         with torch.device("meta"):
             metamodel = FusionEmbedding(4, 4, 4)
-        with torch.device("cuda"):
+        with torch.device(device_type):
             gpumodel = FusionEmbeddingWithHook(4, 4, 4)
         gpumodel_state_dict = get_model_state_dict(gpumodel)
         with self.assertRaisesRegex(RuntimeError, "Missing key"):
@@ -995,8 +1001,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         def forward(self, x):
             return self.fc1(self.fc(x))
 
-        device_mesh = init_device_mesh("cuda", (self.world_size,))
-        model = TestModel().cuda()
+        device_mesh = init_device_mesh(device_type, (self.world_size,))
+        model = TestModel().to(device_type)
         parallelize_module(
             model,
             device_mesh,
@@ -1014,7 +1020,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
 
         optim = torch.optim.AdamW(**optim_kwargs)
         optim.zero_grad()
-        model(torch.randn(64, 64).cuda()).sum().backward()
+        model(torch.randn(64, 64, device=device_type)).sum().backward()
         optim.step()
         optim.zero_grad()
 
@@ -1067,7 +1073,7 @@ class TestNoComm(MultiProcessTestCase):
 
     @skip_if_lt_x_gpu(1)
     def test_no_dist(self) -> None:
-        model = CompositeParamModel(device=torch.device("cuda"))
+        model = CompositeParamModel(device=torch.device(device_type))
         optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
 
         self.assertFalse(dist.is_initialized())
--- a/test/distributed/test_multi_threaded_pg.py
+++ b/test/distributed/test_multi_threaded_pg.py
@@ -25,6 +25,8 @@ from torch.testing._internal.common_distributed import (
 from torch.testing._internal.common_utils import IS_SANDCASTLE, run_tests, TestCase
 
 
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+
 DEFAULT_WORLD_SIZE = 4
 
 
@@ -330,7 +332,7 @@ class TestCollectivesWithBaseClass(MultiThreadedTestCase):
                 return grad_output * result
 
         x = torch.tensor(
-            [dist.get_rank()], dtype=torch.float, device="cuda", requires_grad=True
+            [dist.get_rank()], dtype=torch.float, device=device_type, requires_grad=True
         )
         x = MyFunc.apply(x)
         x.sum().backward()
--- a/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py
+++ b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py
@@ -22,7 +22,7 @@ class ShardedTensorTestBase(MultiProcessTestCase):
         return TEST_GPU_NUM
 
     def init_pg(self, backend="nccl"):
-        if backend not in ["nccl", "gloo", "mpi"]:
+        if backend not in ["nccl", "gloo", "mpi", "hccl"]:
             raise RuntimeError(f"Backend {backend} not supported!")
 
         dist.init_process_group(
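Editor's note: with "hccl" in the allow-list, callers can derive the backend they pass to init_pg instead of hard-coding "nccl". A sketch of how a test might do this after this change (init_pg is the method patched above; the derivation mirrors the constants added elsewhere in this commit):

    import torch
    import torch.distributed as dist

    device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
    backend = dist.get_default_backend_for_device(device_type)  # e.g. nccl, hccl, gloo
    self.init_pg(backend=backend)  # now passes the nccl/gloo/mpi/hccl allow-list check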