[2/N] Port 5 _composable distributed tests to Intel GPU (#159241)

As part of https://github.com/pytorch/pytorch/issues/114850, we are porting distributed tests to Intel GPU. This is the second PR for the _composable test cases; the first was https://github.com/pytorch/pytorch/pull/159118.
We enable Intel GPU support with the following methods, keeping the original code style as much as possible:

- Use "torch.accelerator.current_accelerator()" to determine the accelerator backend
- Enabled XPU for some test path
- Skip some test cases which Intel GPU does not support
- Added "cpu:gloo,xpu:xccl" for distributed backend

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159241
Approved by: https://github.com/guangyey, https://github.com/d4l3k
Author: Zeng, Xiangdong
Date: 2025-09-15 06:24:55 +00:00
Committed by: PyTorch MergeBot
Parent: 06bb32d55e
Commit: 814ba34fa6
6 changed files with 120 additions and 75 deletions

View File

@ -10,10 +10,13 @@ import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
from torch.utils.checkpoint import CheckpointError
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
class MemoryDelta(ContextDecorator):
def __init__(self, device: torch.device):
self.device: torch.device = device
@ -22,16 +25,16 @@ class MemoryDelta(ContextDecorator):
def __enter__(self):
self.active_memory_enter = (
torch.cuda.memory_stats()["active_bytes.all.current"]
if self.device.type == "cuda"
torch.accelerator.memory_stats()["active_bytes.all.current"]
if self.device.type == "cuda" or self.device.type == "xpu"
else 0
)
return self
def __exit__(self, *exc):
self.active_memory_exit = (
torch.cuda.memory_stats()["active_bytes.all.current"]
if self.device.type == "cuda"
torch.accelerator.memory_stats()["active_bytes.all.current"]
if self.device.type == "cuda" or self.device.type == "xpu"
else 0
)
@ -126,7 +129,7 @@ class TestCheckpoint(TestCase):
loss2 = net2(x2).sum()
loss2.backward()
if x.is_cuda:
if x.is_cuda or x.is_xpu:
self.assertTrue(mem2.delta() < mem1.delta())
for p1, p2 in zip(net1.parameters(), net2.parameters()):
@ -137,10 +140,10 @@ class TestCheckpoint(TestCase):
net = ToyModel()
self._test_tensor_only(net, x)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda/xpu")
def test_tensor_only_gpu(self):
x = torch.randn(20, 100, device="cuda:0")
net = ToyModel().to("cuda:0")
x = torch.randn(20, 100, device=f"{device_type}:0")
net = ToyModel().to(f"{device_type}:0")
self._test_tensor_only(net, x)
def test_random_cpu(self):

View File

@ -47,6 +47,8 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_XPU,
xfailIf,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@ -58,6 +60,9 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
@ -73,7 +78,7 @@ class SimpleModel(nn.Module):
return x
def get_input(self):
return torch.rand(4, 5, device="cuda")
return torch.rand(4, 5, device=device_type)
class SimpleModelUneven(nn.Module):
@ -94,7 +99,7 @@ class SimpleModelUneven(nn.Module):
return x
def get_input(self):
return torch.rand(4, 5, device="cuda")
return torch.rand(4, 5, device=device_type)
class TestFullyShard2DTraining(FSDPTest):
@ -105,13 +110,15 @@ class TestFullyShard2DTraining(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
return min(4, torch.accelerator.device_count())
def init_global_mesh(self) -> DeviceMesh:
# Prefer to test with >=4 GPUs, but for 2 GPUs, use 2-way TP
dp_size = 2 if self.world_size > 2 else 1
return init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
device_type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
@skip_if_lt_x_gpu(2)
@ -138,7 +145,7 @@ class TestFullyShard2DTraining(FSDPTest):
torch.manual_seed(42)
model = MLPStack(mlp_dim)
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
model.parallelize(
@ -150,9 +157,8 @@ class TestFullyShard2DTraining(FSDPTest):
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
torch.manual_seed(42 + dp_pg.rank() + 1)
device = torch.device("cuda")
for iter_idx in range(10):
inp = torch.randn((8, mlp_dim), device=device)
inp = torch.randn((8, mlp_dim), device=device_type)
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
@ -162,6 +168,7 @@ class TestFullyShard2DTraining(FSDPTest):
self.assertEqual(losses[0], losses[1])
@skip_if_lt_x_gpu(2)
@xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1881
def test_train_parity_2d_transformer(self):
self.run_subtests(
{"use_shard_placement_fn": [False, True]},
@ -172,12 +179,12 @@ class TestFullyShard2DTraining(FSDPTest):
torch.manual_seed(42)
model_args = ModelArgs(n_layers=3, dropout_p=0.0)
model = Transformer(model_args)
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
dp_size, tp_size = self.world_size // 2, 2
global_mesh = init_device_mesh(
"cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
)
model = Transformer.parallelize(model, global_mesh["tp"], use_seq_parallel=True)
@ -205,7 +212,7 @@ class TestFullyShard2DTraining(FSDPTest):
self.assertEqual(full_param, ref_param)
torch.manual_seed(42 + global_mesh.get_local_rank("dp"))
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type)
for iter_idx in range(5):
ref_loss = ref_model(inp).sum()
loss = model(inp).sum()
@ -242,15 +249,16 @@ class TestFullyShard2DTraining(FSDPTest):
self.assertEqual(full_param, ref_param)
@skip_if_lt_x_gpu(2)
@xfailIf(TEST_XPU) # https://github.com/pytorch/pytorch/issues/156782
def test_tp_with_fsdp_offloading(self):
global_mesh = init_device_mesh(
"cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
device_type, (1, self.world_size), mesh_dim_names=("dp", "tp")
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
torch.manual_seed(42)
mlp_dim = 16
model = MLPStack(mlp_dim)
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
# Parallelize with N-way TP and 1-way FSDP
model.parallelize(
@ -268,7 +276,7 @@ class TestFullyShard2DTraining(FSDPTest):
# NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
# called, but they will just be no-ops without issuing any kernels.
# We prefer to keep the no-op check at the c10d level, not in FSDP.
inp = torch.randn((4, mlp_dim), device="cuda") # same on all ranks
inp = torch.randn((4, mlp_dim), device=device_type) # same on all ranks
for _ in range(10):
ref_optim.zero_grad()
optim.zero_grad()
@ -297,6 +305,7 @@ class TestFullyShard2DTraining(FSDPTest):
ref_optim.step()
@skip_if_lt_x_gpu(2)
@xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1881
@with_temp_dir
def test_train_parity_2d_transformer_checkpoint_resume(self):
"""
@ -352,7 +361,7 @@ class TestFullyShard2DTraining(FSDPTest):
)
torch.manual_seed(42 + global_mesh["dp"].get_local_rank() + 1)
inp = torch.randint(0, model_args.vocab_size, (3, 16), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (3, 16), device=device_type)
loss_no_cp1 = train_step(model_no_cp, optim_no_cp, inp)
loss_no_cp2 = train_step(model_no_cp, optim_no_cp, inp)
@ -410,14 +419,14 @@ class TestFullyShard2DStateDict(DTensorTestBase):
@property
def backend(self):
# need to specify gloo backend for testing cpu offload
return "cpu:gloo,cuda:nccl"
return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
@with_comms
@skip_if_lt_x_gpu(4)
def test_fully_shard_tp_2d_set_full_state_dict(self):
dummy_model = SimpleModel().cuda()
dummy_model = SimpleModel().to(device_type)
mesh_2d = init_device_mesh(
"cuda",
device_type,
(2, self.world_size // 2),
mesh_dim_names=("dp", "tp"),
)
@ -561,7 +570,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
)
model = FSDP(
SimpleModel().cuda(),
SimpleModel().to(device_type),
device_mesh=mesh_2d["dp"],
)
fsdp_state = _get_module_fsdp_state(model)
@ -573,7 +582,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
recompute_activation=False,
) -> None:
torch.manual_seed(0)
model = SimpleModel().cuda(self.rank)
model = SimpleModel().to(f"{device_type}:{self.rank}")
model = FSDP(model, use_orig_params=use_orig_params)
optim = torch.optim.Adam(model.parameters(), lr=0.01)
@ -587,7 +596,9 @@ class TestNew2dParallelTraining(DTensorTestBase):
"net1": ColwiseParallel(),
"net2": RowwiseParallel(),
}
model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, parallelize_plan)
model_2d = parallelize_module(
SimpleModel().to(device_type), tp_mesh, parallelize_plan
)
model_2d = FSDP(
model_2d,
device_mesh=dp_mesh,
@ -615,7 +626,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
# Ensure all input across TP ranks are same.
# TODO: add a get_group_rank() to DeviceMesh.
torch.manual_seed(i + dist.get_rank(dp_mesh.get_group(mesh_dim=0)))
input = torch.rand(4, 5).cuda(self.rank)
input = torch.rand(4, 5).to(f"{device_type}:{self.rank}")
output = model(input)
output_2d = model_2d(input)
self.assertEqual(output, output_2d)
@ -652,7 +663,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
@property
def backend(self):
# need to specify gloo backend for testing cpu offload
return "cpu:gloo,cuda:nccl"
return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
@with_comms
@skip_if_lt_x_gpu(4)
@ -669,7 +680,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
"net3": ColwiseParallel(),
}
model_2d = parallelize_module(
SimpleModel().cuda(),
SimpleModel().to(device_type),
mesh_2d["tp"],
parallelize_plan=parallelize_plan,
)
@ -679,8 +690,10 @@ class TestNew2dParallelStateDict(DTensorTestBase):
isinstance(model_2d_fsdp_state._fsdp_extension, DTensorExtensions)
)
mesh_1d = init_device_mesh("cuda", (self.world_size,))
model_1d = FSDP(SimpleModel().cuda(), device_mesh=mesh_1d, use_orig_params=True)
mesh_1d = init_device_mesh(device_type, (self.world_size,))
model_1d = FSDP(
SimpleModel().to(device_type), device_mesh=mesh_1d, use_orig_params=True
)
model_1d_fsdp_state = _get_module_fsdp_state(model_1d)
self.assertEqual(model_1d_fsdp_state._fsdp_extension, None)
@ -692,7 +705,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
# Create a model without wrapper
torch.manual_seed(0)
no_wrap_model = simple_model().cuda(self.rank)
no_wrap_model = simple_model().to(f"{device_type}:{self.rank}")
no_wrap_state_dict = no_wrap_model.state_dict()
# Create a model and sharded it with 2D FSDP + TP
@ -706,7 +719,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
"net1": ColwiseParallel(),
"net2": RowwiseParallel(),
}
model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
model_2d = parallelize_module(
simple_model().to(device_type), tp_mesh, parallelize_plan
)
model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)
FSDP.set_state_dict_type(
@ -754,7 +769,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
"net1": ColwiseParallel(),
"net2": RowwiseParallel(),
}
model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
model_2d = parallelize_module(
simple_model().to(device_type), tp_mesh, parallelize_plan
)
model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)
optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
@ -768,7 +785,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
ref_state_dict = deepcopy(model_2d.state_dict())
# Update the parameters so model.state_dict() will be different from ref_dtensor_sd.
model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
model_2d(model_2d.get_input().to(f"{device_type}:{self.rank}")).sum().backward()
optim_2d.step()
# Load ref_state_dict back.
@ -799,9 +816,11 @@ class TestNew2dParallelStateDict(DTensorTestBase):
# Create a model without wrapper
torch.manual_seed(0)
no_wrap_model = simple_model().cuda(self.rank)
no_wrap_model = simple_model().to(f"{device_type}:{self.rank}")
no_wrap_optim = torch.optim.Adam(no_wrap_model.parameters(), lr=0.01)
no_wrap_model(no_wrap_model.get_input().cuda(self.rank)).sum().backward()
no_wrap_model(
no_wrap_model.get_input().to(f"{device_type}:{self.rank}")
).sum().backward()
no_wrap_optim.step()
no_wrap_osd = get_optimizer_state_dict(no_wrap_model, optimizers=no_wrap_optim)
@ -815,7 +834,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
"net2": RowwiseParallel(),
}
model_2d = parallelize_module(
simple_model().cuda(), mesh_2d["tp"], parallelize_plan
simple_model().to(device_type), mesh_2d["tp"], parallelize_plan
)
model_2d = FSDP(model_2d, device_mesh=mesh_2d["dp"], use_orig_params=True)
FSDP.set_state_dict_type(
@ -823,7 +842,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
StateDictType.SHARDED_STATE_DICT,
)
optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
model_2d(model_2d.get_input().to(f"{device_type}:{self.rank}")).sum().backward()
optim_2d.step()
optim_2d_osd = get_optimizer_state_dict(model_2d, optimizers=optim_2d)
ref_optim_2d_osd = deepcopy(optim_2d_osd)
@ -842,7 +861,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
# compare with no_wrap state.
if isinstance(dist_state, DTensor):
dist_state = (
dist_state.cuda()
dist_state.to(device_type)
.redistribute(placements=(Replicate(), Replicate()))
.to_local()
)
@ -850,7 +869,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
self.assertTrue(torch.allclose(state, dist_state))
# Update the parameters 2d optim states will be different from ref_optim_state_dict.
model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
model_2d(model_2d.get_input().to(f"{device_type}:{self.rank}")).sum().backward()
optim_2d.step()
set_optimizer_state_dict(
@ -892,8 +911,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
5) dcp.load the state dict from storage
6) load the state dict into the 2D model
"""
dummy_model = SimpleModel().cuda()
mesh_1d = init_device_mesh("cuda", (self.world_size,))
dummy_model = SimpleModel().to(device_type)
mesh_1d = init_device_mesh(device_type, (self.world_size,))
model = FSDP(dummy_model, device_mesh=mesh_1d)
optim = torch.optim.Adam(model.parameters(), lr=0.01)
model(model.get_input()).sum().backward()
@ -911,9 +930,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
dcp.save(state_dict, checkpoint_id=self.temp_dir)
# initialize 2d model
dummy_model = SimpleModel().cuda()
dummy_model = SimpleModel().to(device_type)
mesh_2d = init_device_mesh(
"cuda",
device_type,
(2, self.world_size // 2),
mesh_dim_names=("dp", "tp"),
)

View File

@ -30,7 +30,7 @@ from torch.distributed.tensor.parallel import (
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_nccl,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
@ -38,6 +38,7 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_XPU,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
@ -46,6 +47,10 @@ if TYPE_CHECKING:
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
backend = torch.distributed.get_default_backend_for_device(device_type)
# MLP Layer
class MLPModule(torch.nn.Module):
def __init__(self, d_hid: int):
@ -79,7 +84,7 @@ class ComposabilityTest(MultiProcessTestCase):
@classmethod
def backend_str(cls) -> str:
# Testing with NCCL backend
return "nccl"
return backend
def setUp(self):
super().setUp()
@ -100,9 +105,11 @@ class ComposabilityTest(MultiProcessTestCase):
def device(self):
return self.rank
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_if_lt_x_gpu(4)
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 4+ GPUs"
)
def test_pp_and_dcp(self):
"""
Test that pipeline parallelism and distributed checkpointing can be used together and
@ -143,11 +150,11 @@ class ComposabilityTest(MultiProcessTestCase):
x = layer(x)
return x
device = torch.device("cuda", self.device)
torch.cuda.set_device(self.device)
device = torch.device(device_type, self.device)
torch.accelerator.set_device_index(self.device)
store = torch.distributed.FileStore(self.file_name, self.world_size)
torch.distributed.init_process_group(
backend="nccl",
backend=backend,
store=store,
rank=self.rank,
world_size=self.world_size,
@ -192,9 +199,11 @@ class ComposabilityTest(MultiProcessTestCase):
_dcp_test(self)
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 8+ GPUs")
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@parametrize(
"ScheduleClass",
[
@ -213,11 +222,11 @@ class ComposabilityTest(MultiProcessTestCase):
],
)
def test_3d_with_tp_dp_pp(self, ScheduleClass, MixedPrecisionParam):
_device_raii = torch.device("cuda", self.device)
torch.cuda.set_device(self.device)
_device_raii = torch.device(device_type, self.device)
torch.accelerator.set_device_index(self.device)
store = torch.distributed.FileStore(self.file_name, self.world_size)
torch.distributed.init_process_group(
backend="nccl",
backend=backend,
store=store,
rank=self.rank,
world_size=self.world_size,
@ -228,7 +237,7 @@ class ComposabilityTest(MultiProcessTestCase):
num_microbatches = 8
dp_size = self.world_size // (tp_size * pp_size)
device_mesh = init_device_mesh(
"cuda",
device_type,
mesh_shape=(dp_size, pp_size, tp_size),
mesh_dim_names=("dp", "pp", "tp"),
)

View File

@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]
import os
import unittest
from copy import deepcopy
import torch
@ -14,7 +15,11 @@ from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import run_tests, TEST_XPU
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
device_module = torch.get_device_module(device_type)
class Net(nn.Module):
@ -154,6 +159,7 @@ class ReplicateTest(MultiProcessTestCase):
self._compare_module(model, replicate_model)
@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
def test_replicate_move_args_kwargs_to_device(self):
class MyNet(nn.Module):
def __init__(self) -> None:
@ -166,24 +172,25 @@ class ReplicateTest(MultiProcessTestCase):
return self.a(inp)
self._init_pg()
torch.cuda.set_device(self.rank)
model = MyNet().cuda()
replicate(model, device_id=torch.cuda.current_device())
torch.accelerator.set_device_index(self.rank)
model = MyNet().to(device_type)
replicate(model, device_id=torch.accelerator.current_device_index())
# CPU input ensures replicate can move arg and kwargs to device.
a, b = torch.randn(2, 2), torch.randn(2, 2)
model(a, kwarg=b).sum().backward()
@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
def test_replicate_ignore_module(self):
self._init_pg()
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
# Seed ensures diff input and thus different local grads across ranks.
torch.manual_seed(self.rank)
torch.cuda.manual_seed(self.rank)
model = Net().cuda()
device_module.manual_seed(self.rank)
model = Net().to(device_type)
replicate(model, ignored_modules=[model.fc1])
# CPU input ensures that replicate can move input to GPU as DDP does.
inp = torch.randn(5, 2, device="cuda") * (self.rank + 1)
inp = torch.randn(5, 2, device=device_type) * (self.rank + 1)
out = model(inp) * 10
out.sum().backward()
# FC1 grads should not be synchronized, FC2 and 3 should be.
@ -221,10 +228,11 @@ class ReplicateTest(MultiProcessTestCase):
self._compare_module(model, replicate_model)
@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
def test_replicate_device_id(self):
self._init_pg()
model = Net()
model_cuda = deepcopy(model).cuda()
model_cuda = deepcopy(model).to(device_type)
model_cuda2 = deepcopy(model_cuda)
replicate(model, device_id=torch.device("cpu"))
# DDP instance is attached in first pre forward
@ -233,13 +241,15 @@ class ReplicateTest(MultiProcessTestCase):
# Should be None for CPU training
self.assertEqual(None, replicate_ddp_weakref.device_ids)
replicate(model_cuda, device_id=torch.device(torch.cuda.current_device()))
replicate(
model_cuda, device_id=torch.device(torch.accelerator.current_device_index())
)
# DDP instance is attached in first pre forward
model_cuda(torch.randn(2, 2))
replicate_ddp_weakref = replicate.state(model_cuda)._ddp_weakref()
self.assertEqual([0], replicate_ddp_weakref.device_ids)
# Pass in int as device_id
replicate(model_cuda2, device_id=int(torch.cuda.current_device()))
replicate(model_cuda2, device_id=int(torch.accelerator.current_device_index()))
# DDP instance is attached in first pre forward
model_cuda2(torch.randn(2, 2))
replicate_ddp_weakref = replicate.state(model_cuda2)._ddp_weakref()
@ -256,6 +266,7 @@ class ReplicateTest(MultiProcessTestCase):
class ReplicateFullyShardInit(ReplicateTest):
@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
def test_replicate_fully_shard_init(self):
class ToyModel(nn.Module):
def __init__(self, dim: int):
@ -273,14 +284,14 @@ class ReplicateFullyShardInit(ReplicateTest):
return y
self._init_pg()
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
dim = 3
bz = 2
model = ToyModel(dim).cuda()
model = ToyModel(dim).to(device_type)
for linear in model.linears:
fully_shard(linear)
fully_shard(model.linears)
replicate(model, device_id=torch.cuda.current_device())
replicate(model, device_id=torch.accelerator.current_device_index())
for linear in model.linears:
self.assertTrue(isinstance(linear.weight, DTensor))
inp = torch.rand(bz, dim)

View File

@ -98,6 +98,8 @@ class ReplicateTest(MultiProcessInductorTestCase):
self.create_pg(device)
torch._dynamo.config.optimize_ddp = "python_reducer"
torch.manual_seed(123)
if device_type == "xpu":
torch.use_deterministic_algorithms(True, warn_only=True)
model = Net(checkpoint=checkpoint).to(device)
input = torch.randn([1, DIM], device=device)

View File

@ -388,6 +388,7 @@ class DTensorTestBase(MultiProcessTestCase):
"hccl",
"xccl",
"fake",
"cpu:gloo,xpu:xccl",
]:
raise RuntimeError(f"Backend {backend} not supported!")