Compare commits


4 Commits

SHA1  Message  Date

060be44d24  Revise test case name to test_sharded_tensor_to_cuda  2025-11-13 06:59:56 +00:00

2277914b92  Fix error:
            1. Use torch.accelerator when available.
            2. Default to cuda when torch.accelerator is not available.
            2025-11-13 06:59:56 +00:00

e7f1f4c09a  Fix lint errors  2025-11-13 06:59:56 +00:00

e6058040f6  Port 3 distributed/_shard tests to Intel GPU.
            We enable Intel GPU with the following methods while keeping the original
            code style (the device/backend selection pattern is sketched after the
            change summary below):
            1. Use torch.accelerator.current_accelerator() to determine the accelerator backend.
            2. Enable XPU for some test paths.
            3. Skip test cases that Intel GPU does not support.
            2025-11-13 06:59:56 +00:00
3 changed files with 650 additions and 523 deletions
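
Note: the device and backend selection described in the last commit message is the backbone of both files shown below. A minimal standalone sketch of the same idea, assuming a PyTorch build that ships torch.accelerator and torch.distributed.get_default_backend_for_device (the fallback to "cuda" mirrors the follow-up fix commit):

    import torch
    import torch.distributed as dist

    # Resolve the accelerator type ("cuda" on NVIDIA, "xpu" on Intel GPU);
    # default to "cuda" when no accelerator is available, as the fix commit does.
    if torch.accelerator.is_available():
        DEVICE_TYPE = torch.accelerator.current_accelerator().type
    else:
        DEVICE_TYPE = "cuda"

    # Ask torch.distributed for the matching default backend
    # (e.g. "nccl" for cuda, "xccl" for xpu).
    BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE)

    print(f"running tests on {DEVICE_TYPE} with backend {BACKEND}")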

File diff suppressed because it is too large.

View File

@@ -6,7 +6,10 @@ from itertools import product
 import torch
 from torch.distributed._shard import _shard_tensor, sharded_tensor
 from torch.distributed._shard.sharding_spec import EnumerableShardingSpec, ShardMetadata
-from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import (
+    requires_accelerator_dist_backend,
+    skip_if_lt_x_gpu,
+)
 from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
 from torch.testing._internal.distributed._shard.sharded_tensor import (
     ShardedTensorTestBase,
@@ -17,6 +20,13 @@ from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
 )
 
+if torch.accelerator.is_available():
+    DEVICE_TYPE = torch.accelerator.current_accelerator().type
+else:
+    # use cuda as default device type for testing when accelerator is not available
+    DEVICE_TYPE = "cuda"
+BACKEND = torch.distributed.get_default_backend_for_device(DEVICE_TYPE)
+
 
 if TEST_WITH_DEV_DBG_ASAN:
     print(
         "Skip dev-asan as torch + multiprocessing spawn have known issues",
@@ -28,7 +38,7 @@ if TEST_WITH_DEV_DBG_ASAN:
 class TestReshard(ShardedTensorTestBase):
     def _run_sharded_tensor_reshard(self, sharding_spec, reshard_spec, input_size):
         torch.manual_seed(0)
-        local_tensor = torch.rand(*input_size).cuda(self.rank)
+        local_tensor = torch.rand(*input_size).to(torch.device(self.rank))
         st = _shard_tensor(local_tensor, sharding_spec)
         st_compare = _shard_tensor(local_tensor, reshard_spec)
         st.reshard(reshard_spec)
@@ -43,9 +53,9 @@ class TestReshard(ShardedTensorTestBase):
             st.local_shards()[0].metadata, st_compare.local_shards()[0].metadata
         )
 
-    @with_comms(init_rpc=False)
+    @with_comms(init_rpc=False, backend=BACKEND)
     @skip_if_lt_x_gpu(4)
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     def test_sharded_tensor_reshard(self):
         dims = [0, 1]
         for sharding_dim, reshard_dim in product(dims, dims):
@@ -58,9 +68,9 @@ class TestReshard(ShardedTensorTestBase):
             self._run_sharded_tensor_reshard(spec, reshard_spec, [15, 26])
             self._run_sharded_tensor_reshard(spec, reshard_spec, [12, 24])
 
-    @with_comms(init_rpc=False)
+    @with_comms(init_rpc=False, backend=BACKEND)
     @skip_if_lt_x_gpu(4)
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     def test_sharded_tensor_reshard_errors(self):
         specs = _chunk_sharding_specs_list_for_test([0, 1], seed=6)
         spec, reshard_spec = specs[0], specs[1]
@@ -69,12 +79,12 @@ class TestReshard(ShardedTensorTestBase):
                 ShardMetadata(
                     shard_offsets=[0, 0],
                     shard_sizes=[5, 5],
-                    placement="rank:0/cuda:0",
+                    placement=f"rank:0/{DEVICE_TYPE}:0",
                 ),
                 ShardMetadata(
                     shard_offsets=[5, 0],
                     shard_sizes=[5, 5],
-                    placement="rank:1/cuda:1",
+                    placement=f"rank:1/{DEVICE_TYPE}:1",
                 ),
             ]
         )
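
A note on the recurring change in both files: replacing .cuda(self.rank) with .to(torch.device(self.rank)) works because a bare integer passed to torch.device is interpreted as a device index on the current accelerator. A small sketch of that behaviour, assuming a recent PyTorch with an accelerator present (rank stands in for the test's self.rank):

    import torch

    rank = 0  # stand-in for self.rank in the tests

    # An integer device index resolves against the current accelerator type,
    # so the same line lands on cuda:0 or xpu:0 without any branching.
    dev = torch.device(rank)
    t = torch.rand(4, 4).to(dev)  # replaces the old .cuda(rank) call
    print(t.device)  # e.g. cuda:0 on NVIDIA, xpu:0 on Intel GPU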

View File

@@ -8,7 +8,10 @@ from torch.distributed._shard import shard_module
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._shard.sharding_plan import ShardingPlan, ShardingPlanner
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec
-from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import (
+    requires_accelerator_dist_backend,
+    skip_if_lt_x_gpu,
+)
 from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
 from torch.testing._internal.distributed._shard.sharded_tensor import (
     ShardedTensorTestBase,
@@ -28,6 +31,13 @@ if TEST_WITH_DEV_DBG_ASAN:
     )
     sys.exit(0)
 
+if torch.accelerator.is_available():
+    DEVICE_TYPE = torch.accelerator.current_accelerator().type
+else:
+    # use cuda as default device type for testing when accelerator is not available
+    DEVICE_TYPE = "cuda"
+BACKEND = torch.distributed.get_default_backend_for_device(DEVICE_TYPE)
+
 
 # Example ShardingPlanner that chunks every parameter in the module
 # to all available devices defined.
@@ -37,7 +47,7 @@ class ChunkAllShardingPlanner(ShardingPlanner):
     def __init__(self, chunk_dim=0, device_count=0):
         self.dim = chunk_dim
-        self.devices = [f"rank:{i}/cuda:{i}" for i in range(device_count)]
+        self.devices = [f"rank:{i}/{DEVICE_TYPE}:{i}" for i in range(device_count)]
 
     def build_plan(self, module: nn.Module) -> ShardingPlan:
         named_params = module.named_parameters()
@@ -49,9 +59,9 @@ class ChunkAllShardingPlanner(ShardingPlanner):
 
 
 class TestShardingPlan(ShardedTensorTestBase):
-    @with_comms(init_rpc=False)
+    @with_comms(init_rpc=False, backend=BACKEND)
     @skip_if_lt_x_gpu(TEST_GPU_NUM)
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     def test_sharding_plan_errors(self):
         rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)[0]
         sharding_plan_wrong_plan = ShardingPlan(
@@ -61,7 +71,7 @@ class TestShardingPlan(ShardedTensorTestBase):
             output_plan={"": rowwise_sharding_spec},
         )
 
-        megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).cuda(self.rank)
+        megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).to(torch.device(self.rank))
 
         with self.assertRaisesRegex(
             TypeError, "Only `ShardingSpec` and `Sharder` are supported to shard"
@@ -100,12 +110,12 @@ class TestShardingPlan(ShardedTensorTestBase):
             # shard the module with the provided sharding plan
             shard_module(megatron_lm, sharding_plan_wrong_param_path)
 
-    @with_comms(init_rpc=False)
+    @with_comms(init_rpc=False, backend=BACKEND)
     @skip_if_lt_x_gpu(TEST_GPU_NUM)
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     def test_custom_sharding_planner(self):
-        megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]], rank=self.rank).cuda(
-            self.rank
+        megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]], rank=self.rank).to(
+            torch.device(self.rank)
         )
         planner = ChunkAllShardingPlanner(device_count=TEST_GPU_NUM)
         sharding_plan = planner.build_plan(megatron_lm)
@@ -118,23 +128,23 @@ class TestShardingPlan(ShardedTensorTestBase):
         self.assertTrue(isinstance(megatron_lm.fc1.bias, ShardedTensor))
         self.assertTrue(isinstance(megatron_lm.fc2.bias, ShardedTensor))
 
-    @with_comms(init_rpc=False)
+    @with_comms(init_rpc=False, backend=BACKEND)
     @skip_if_lt_x_gpu(TEST_GPU_NUM)
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     def test_shard_module_sub_process_group(self):
         megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]], rank=self.rank)
         colwise_sharding_spec = ChunkShardingSpec(
             dim=0,
             placements=[
-                "rank:2/cuda:2",
-                "rank:3/cuda:3",
+                f"rank:2/{DEVICE_TYPE}:2",
+                f"rank:3/{DEVICE_TYPE}:3",
             ],
         )
         rowwise_sharding_spec = ChunkShardingSpec(
             dim=1,
             placements=[
-                "rank:2/cuda:2",
-                "rank:3/cuda:3",
+                f"rank:2/{DEVICE_TYPE}:2",
+                f"rank:3/{DEVICE_TYPE}:3",
             ],
         )
         sharding_plan = ShardingPlan(
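
For context on the decorator swap: requires_accelerator_dist_backend(["nccl", "xccl"]) gates each test on one of the listed distributed backends being available, replacing the CUDA-only requires_nccl(), while with_comms(backend=BACKEND) presumably initializes the process group with the backend selected at module scope. A rough, hypothetical equivalent of the gating decorator (not the actual PyTorch helper, just the idea):

    import unittest
    import torch.distributed as dist

    def requires_backend(backends):
        # Hypothetical stand-in: skip the test unless one of the listed
        # distributed backends is compiled into this torch build.
        available = any(dist.is_backend_available(b) for b in backends)
        return unittest.skipUnless(available, f"none of {backends} is available")

In the diff this gating is combined with skip_if_lt_x_gpu(...), so the tests also skip when fewer than the required number of devices are present.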