Compare commits

...

22 Commits

SHA1 Message Date
610f9b437d fix another bug 2025-11-14 06:44:20 +00:00
7d0eb9b4f6 fix lint issue 2025-11-14 06:44:20 +00:00
af6ae22dbd fix lint issue 2025-11-14 06:44:20 +00:00
e3afc32110 revert the change 2025-11-14 06:44:20 +00:00
15aa7e01a9 fix the rebase issue 2025-11-14 06:44:20 +00:00
fc5133bacb fix lint issue 2025-11-14 06:44:20 +00:00
1a49d0cda4 update according to review 2025-11-14 06:44:20 +00:00
e9a3814dea revert the change that already in other's pr 2025-11-14 06:44:20 +00:00
2ace9e465a revert change of case 2025-11-14 06:44:20 +00:00
d990b72872 update hpu to acc 2025-11-14 06:44:20 +00:00
a8243bd1d4 update 2025-11-14 06:44:20 +00:00
1ccc757cac skip failed case 2025-11-14 06:44:20 +00:00
2abf4ecf2f port distributed tensor case for Intel GPU 2025-11-14 06:44:20 +00:00
ff3e2942b4 update according to review 2025-11-14 06:44:19 +00:00
a81e5177de revert change of case 2025-11-14 06:44:19 +00:00
f02dba7893 fix a bug 2025-11-14 06:44:19 +00:00
09abf0ceff update hpu to acc 2025-11-14 06:44:19 +00:00
b4d23566db remove redundant skipper 2025-11-14 06:44:19 +00:00
20ca3c48de update 2025-11-14 06:44:19 +00:00
d83d25dee4 skip failed case 2025-11-14 06:44:19 +00:00
528d3fc4ce enable for xpu 2025-11-14 06:44:19 +00:00
fd178b2e17 port distributed tensor case for Intel GPU 2025-11-14 06:44:19 +00:00
6 changed files with 36 additions and 31 deletions

View File

@@ -6,7 +6,7 @@ import torch.distributed._functional_collectives as funcol
 import torch.nn as nn
 from torch.distributed.tensor import DeviceMesh, DTensor, Shard
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.testing._internal.common_distributed import requires_nccl
+from torch.testing._internal.common_distributed import requires_accelerator_dist_backend
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
 from torch.testing._internal.distributed.fake_pg import FakeStore
@@ -14,6 +14,9 @@ from torch.testing._internal.distributed.fake_pg import FakeStore
 c10d_functional = torch.ops.c10d_functional
 c10d_ops = torch.ops.c10d
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
 class TestCommMode(TestCase):
@@ -28,7 +31,7 @@ class TestCommMode(TestCase):
         dist.init_process_group(
             backend="fake", rank=1, world_size=self.world_size, store=store
         )
-        self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device_type = device_type
         self.world_pg = dist.distributed_c10d._get_default_group()
     def checksAssert(self, comm_mode, key, expected_value, expected_total_value):
@@ -111,12 +114,12 @@ class TestCommMode(TestCase):
         self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
         self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 0)
-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     def test_comm_mode_with_c10d(self):
-        if not torch.cuda.is_available():
+        if not torch.accelerator.is_available():
             return
-        inp = torch.rand(2, 8, 16).cuda()
+        inp = torch.rand(2, 8, 16).to(device_type)
         all_gather_out = inp.new_empty(self.world_size * 2, 8, 16)
         comm_mode = CommDebugMode()
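The hunks above swap hard-coded CUDA checks for PyTorch's device-agnostic accelerator API. A minimal sketch of that pattern, assuming a PyTorch build where `torch.accelerator` is available (2.6 or later); the tensor shape below is only illustrative:

```python
import torch

# Resolve the active accelerator type ("cuda", "xpu", "hpu", ...) or fall back to CPU.
# current_accelerator(True) returns None when no accelerator is actually usable.
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)

# Tensors can then be placed without calling .cuda() directly.
inp = torch.rand(2, 8, 16).to(device_type)
print(inp.device)
```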

View File

@@ -658,11 +658,11 @@ class DTensorMeshTest(DTensorTestBase):
     @with_comms
     def test_dtensor_device_mesh_device_conversion(self):
-        # construct a cuda device mesh
+        # construct a gpu device mesh
         mesh = self.build_device_mesh()
-        # construct from a cpu local tensor with cuda device mesh
-        # should automatically convert the dist tensor to cuda
+        # construct from a cpu local tensor with gpu device mesh
+        # should automatically convert the dist tensor to gpu
         placements = [Shard(0)]
         local_tensor = torch.randn(3, 3)
         dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
@@ -711,7 +711,7 @@ class DTensorMeshTest(DTensorTestBase):
     @with_comms
     def test_dtensor_2d_mesh(self):
         mesh_tensor = torch.arange(self.world_size).reshape(2, 4)
-        # construct a cuda device mesh
+        # construct a gpu device mesh
         mesh = DeviceMesh(self.device_type, mesh_tensor)
         # construct a dist tensor on 2d device mesh and test if works
@@ -733,7 +733,7 @@ class DTensorMeshTest(DTensorTestBase):
     @with_comms
     def test_device_mesh_nd(self):
-        # construct a cuda device mesh
+        # construct a gpu device mesh
         mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2)
         mesh = DeviceMesh(self.device_type, mesh_tensor)
         # construct a dist tensor on 3d device mesh and test if works
@@ -1064,8 +1064,8 @@ class TestDTensorPlacementTypes(DTensorTestBase):
         # Keep everything deterministic.
         torch.manual_seed(0)
         tensor = torch.rand(size)
-        if self.device_type == "cuda":
-            return tensor.cuda()
+        if self.device_type != "cpu":
+            return tensor.to(self.device_type)
         else:
             return tensor
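The placement-type hunk applies the same idea to tensor movement: any non-CPU device takes the generic `Tensor.to` path instead of a CUDA-only `.cuda()` call. A sketch of that pattern; the helper name below is illustrative, not part of the diff:

```python
import torch

def make_shard_source(size, device_type: str) -> torch.Tensor:
    """Hypothetical helper mirroring the pattern in the hunk above."""
    torch.manual_seed(0)  # keep everything deterministic
    tensor = torch.rand(size)
    if device_type != "cpu":
        # Works for cuda, xpu, hpu, ... without device-specific branches.
        return tensor.to(device_type)
    return tensor
```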

View File

@@ -39,6 +39,7 @@ from torch.distributed.tensor.parallel import (
     RowwiseParallel,
 )
 from torch.distributed.tensor.placement_types import _StridedShard
+from torch.testing._internal.common_device_type import skipXPUIf
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import (
@@ -47,8 +48,6 @@ from torch.testing._internal.common_utils import (
     run_tests,
     skipIfHpu,
     skipIfTorchDynamo,
-    TEST_CUDA,
-    TEST_HPU,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -95,6 +94,10 @@ aot_eager_graph = aot_autograd(
     partition_fn=min_cut_rematerialization_partition,
 )
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
 def _apply_sharding(mod: nn.Module, shard_dim: int, device_mesh: DeviceMesh):
     """
@@ -141,7 +144,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
     @property
     def device_type(self) -> str:
-        return "cuda" if TEST_CUDA else "hpu" if TEST_HPU else "cpu"
+        return device_type
     @property
     def world_size(self) -> int:
@@ -160,9 +163,9 @@
         res = fn(x)
         res.to_local().sum().backward()
-    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "accelerator not available")
     def test_dtensor_basic_export(self):
-        mesh = DeviceMesh("cuda", torch.arange(self.world_size))
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
         param = torch.randn(4, 4)
         param_x = DTensor.from_local(param, mesh, [Shard(0)], run_check=False)
@@ -188,10 +191,10 @@
         )
         self.assertExpectedInline(
             str(ep.graph_module.code).strip(),
-            """\
+            f"""\
 def forward(self, b_buffer, x):
     _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(x, dtype = torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
-    to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda')); x = None
+    to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}')); x = None
     view_as = torch.ops.aten.view_as.default(to, to); to = None
     dtensor___init__0 = self.dtensor___init__0
     dtensor_const_func_spec0 = self.dtensor_const_func_spec0
@@ -206,10 +209,10 @@ def forward(self, b_buffer, x):
         # add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add
         self.assertExpectedInline(
             str(ep.run_decompositions({}).graph_module.code).strip(),
-            """\
+            f"""\
 def forward(self, b_parametrizations_buffer_original0, x):
     _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None
-    _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None
+    _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}', index=0)); x = None
     view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None
     add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
     view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None
@@ -377,6 +380,7 @@ def forward(self, b_parametrizations_buffer_original0, x):
         self.assertEqual(res, ref)
     @skipIfHpu
+    @skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1981")
     def test_dtensor_dynamic_loss_parallel_log_softmax(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@@ -815,13 +819,13 @@ def forward(self, b_parametrizations_buffer_original0, x):
             out = layer_norm.permute(0, 2, 1)
             return out
-        x = torch.randn(4, 2, 4, requires_grad=True, device="cuda")
+        x = torch.randn(4, 2, 4, requires_grad=True, device=self.device_type)
         x_dt = DTensor.from_local(x, mesh, [Shard(1)], run_check=False)
-        y = torch.randn(4, requires_grad=True, device="cuda")
+        y = torch.randn(4, requires_grad=True, device=self.device_type)
         y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False)
-        z = torch.randn(4, requires_grad=True, device="cuda")
+        z = torch.randn(4, requires_grad=True, device=self.device_type)
         z_dt = DTensor.from_local(z, mesh, [Replicate()], run_check=False)
         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
@@ -919,7 +923,7 @@ def forward(self, b_parametrizations_buffer_original0, x):
         # pass in tensor as inputs/outputs, create DTensor and run redistribute
         # (allgather collective) inside the fn
         def fn(x_dt):
-            if x_dt.device_mesh.device_type == "cuda":
+            if x_dt.device_mesh.device_type == f"{self.device_type}":
                 return x_dt + 1
             else:
                 return x_dt + 2
@@ -1051,7 +1055,7 @@ def forward(self, primals_1):
         model = FakeTransformer().to(self.device_type)
-        tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
+        tp_mesh = init_device_mesh(self.device_type, (2,), mesh_dim_names=("tp",))
         # apply sequence parallel
         parallel_plan = {
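Several expected-inline assertions in this file embed a device string in the printed graph, so the plain string literals become f-strings keyed on the test's `device_type` and the same assertion passes on CUDA and XPU. A simplified illustration of why that matters; the strings below are examples, not the real expected graphs:

```python
device_type = "xpu"  # illustrative; the test derives this via torch.accelerator

# Plain literal: only ever matches on CUDA machines.
expected_cuda_only = "device = device(type='cuda')"

# f-string keyed on the detected device: matches on any accelerator.
expected_any_accel = f"device = device(type='{device_type}')"
print(expected_any_accel)
```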

View File

@@ -27,8 +27,6 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
-    TEST_CUDA,
-    TEST_HPU,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     create_local_tensor_test_class,
@@ -541,7 +539,7 @@ class RedistributeTest(DTensorTestBase):
         local_out_dt = out_dt.to_local()
         local_expected_dt = expected_dt.to_local()
         self.assertEqual(out_dt.to_local(), expected_dt.to_local())
-        if TEST_HPU or TEST_CUDA:
+        if torch.accelerator.is_available():
             self.assertEqual(
                 comm_mode.get_comm_counts()[
                     torch.ops._dtensor.shard_dim_alltoall
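Here the guard around the all-to-all collective count drops the per-backend `TEST_CUDA`/`TEST_HPU` flags in favor of one generic availability check. A minimal sketch, again assuming the `torch.accelerator` API:

```python
import torch

# One check covers cuda, xpu, hpu, mps, ... instead of backend-specific flags.
if torch.accelerator.is_available():
    print("accelerator present:", torch.accelerator.current_accelerator())
else:
    print("CPU-only run; skipping collective-count assertions")
```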

View File

@@ -296,8 +296,8 @@ class DistTensorOpsTest(DTensorTestBase):
         self.assertEqual(dist_tensor.dtype, torch.float32)
         self.assertEqual(zeros_like_dt.dtype, torch.bfloat16)
-    @with_comms
     @skip_if_lt_x_gpu(4)
+    @with_comms
     def test_stack(self):
         mesh_2d = DeviceMesh(
             self.device_type, torch.arange(self.world_size).reshape(2, 2)

View File

@@ -387,7 +387,7 @@ class DTensorTestBase(MultiProcessTestCase):
     @property
     def backend(self) -> str:
-        backend = dist.get_default_backend_for_device(DEVICE_TYPE)
+        backend = dist.get_default_backend_for_device(self.device_type)
         return backend
     def init_manual_seed_for_rank(self) -> None:
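Deriving the backend from `self.device_type` (instead of a module-level `DEVICE_TYPE` constant) lets subclasses that override `device_type`, for example to "xpu", pick up the matching process-group backend automatically. A hedged usage sketch of `dist.get_default_backend_for_device`, which the diff itself relies on; the exact mappings depend on the installed PyTorch build:

```python
import torch.distributed as dist

# Typical mappings: "cuda" -> "nccl", "xpu" -> "xccl", "cpu" -> "gloo".
for dev in ("cpu", "cuda", "xpu"):
    try:
        print(dev, "->", dist.get_default_backend_for_device(dev))
    except Exception as exc:  # unsupported device types may raise
        print(dev, "-> unsupported:", exc)
```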