Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)
- During op dispatch, local tensor is supposed to collect rng state from the CPU and CUDA devices so that it can be reset before the op executes for each rank, ensuring that ops with randomness produce the same result on all ranks (note that a separate change is planned to add support for per-rank rng state). Previously we relied on the op's input arguments to deduce which devices to get rng state from, which does not work for factory functions such as torch.randn. Hence this change switches to unconditionally collecting rng state from all devices.
- Fixes per-rank-specific computations in the _MaskedPartial and Shard placements discovered during test enablement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: f18041cca8
Commit: c4f6619330
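To illustrate the first point of the commit message, here is a minimal sketch of collecting rng state from all devices unconditionally rather than deducing the devices from the op's arguments. The helper names and structure are illustrative only and are not the actual LocalTensor internals:

import torch

def snapshot_rng_states() -> dict[str, torch.Tensor]:
    # Collect rng state from the CPU and from every visible CUDA device,
    # regardless of which devices the op's arguments happen to live on, so
    # that factory functions like torch.randn are covered as well.
    states = {"cpu": torch.get_rng_state()}
    if torch.cuda.is_available():
        for idx in range(torch.cuda.device_count()):
            states[f"cuda:{idx}"] = torch.cuda.get_rng_state(idx)
    return states

def restore_rng_states(states: dict[str, torch.Tensor]) -> None:
    # Reset every device to the captured state before running the op for the
    # next rank, so ops with randomness produce identical results on all ranks.
    torch.set_rng_state(states["cpu"])
    if torch.cuda.is_available():
        for idx in range(torch.cuda.device_count()):
            torch.cuda.set_rng_state(states[f"cuda:{idx}"], idx)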
@@ -17,6 +17,7 @@ from torch.distributed.tensor.debug import CommDebugMode
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import run_tests, skipIfRocm
 from torch.testing._internal.distributed._tensor.common_dtensor import (
+    create_local_tensor_test_class,
     DTensorConverter,
     DTensorTestBase,
     with_comms,
@@ -704,6 +705,12 @@ class DistTensorOpsTest(DTensorTestBase):
 
     @with_comms
     def test_dtensor_dtype_conversion(self):
+        from torch.distributed.tensor.debug import (
+            _clear_sharding_prop_cache,
+            _get_sharding_prop_cache_info,
+        )
+
+        _clear_sharding_prop_cache()
         device_mesh = self.build_device_mesh()
         shard_spec = [Shard(0)]
         # by default we start from bf16 dtype
@@ -722,8 +729,6 @@ class DistTensorOpsTest(DTensorTestBase):
         self.assertEqual(bf16_sharded_dtensor1.dtype, torch.bfloat16)
         self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16)
 
-        from torch.distributed.tensor.debug import _get_sharding_prop_cache_info
-
         # by this point we only have cache misses
         hits, misses, _, _ = _get_sharding_prop_cache_info()
         self.assertEqual(hits, 0)
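The hits/misses unpacking above follows the functools.lru_cache cache_info() convention, and clearing the cache at the top of the test is what makes the `self.assertEqual(hits, 0)` assertion deterministic when other tests have already run in the same process. A self-contained illustration of that pattern, using lru_cache as a stand-in rather than the actual DTensor sharding-propagation cache:

import functools

@functools.lru_cache(maxsize=None)
def propagate(spec: str) -> str:
    # Stand-in for an expensive sharding-propagation computation.
    return spec.upper()

propagate.cache_clear()                      # analogous to _clear_sharding_prop_cache()
propagate("Shard(0)")                        # first call: a miss
propagate("Shard(0)")                        # repeated call: a hit
hits, misses, _, _ = propagate.cache_info()  # analogous to _get_sharding_prop_cache_info()
assert (hits, misses) == (1, 1)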
@@ -775,7 +780,7 @@ class DistTensorOpsTest(DTensorTestBase):
         )
 
     def _test_split_on_partial(self, reduce_op: str, split_size: int, split_dim: int):
-        torch.manual_seed(self.rank)
+        self.init_manual_seed_for_rank()
        mesh = self.build_device_mesh()
 
         partial_tensor = torch.randn(8, 8, device=self.device_type)
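The switch from torch.manual_seed(self.rank) to self.init_manual_seed_for_rank() presumably routes per-rank seeding through the test base class so the same test body also behaves correctly when ranks are simulated within a single process under LocalTensor mode; this is an inference from the surrounding change, not something the diff states.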
@@ -822,5 +827,9 @@ class DistTensorOpsTest(DTensorTestBase):
         self.assertEqual(x.full_tensor(), y)
 
 
+DistTensorOpsTestWithLocalTensor = create_local_tensor_test_class(
+    DistTensorOpsTest,
+)
+
 if __name__ == "__main__":
     run_tests()
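The DistTensorOpsTestWithLocalTensor assignment is what actually enables these tests in local tensor mode: create_local_tensor_test_class produces a sibling test class that re-runs DistTensorOpsTest's methods under LocalTensor. As a rough sketch of the pattern only (the real factory lives in torch.testing._internal.distributed._tensor.common_dtensor and does more than this):

def make_local_tensor_variant(base_cls: type, suffix: str = "WithLocalTensor") -> type:
    # Build a subclass with a distinct name so the unittest loader discovers
    # and reports it separately from the original test class.
    return type(base_cls.__name__ + suffix, (base_cls,), {})

The generated class can then be selected on its own, for example with pytest -k DistTensorOpsTestWithLocalTensor against the test file.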
@@ -53,7 +53,13 @@ class ProcessGroupTest(TestCase):
 
 
 class Dist2MultiProcessTestCase(MultiProcessTestCase):
-    device: torch.device
+    @property
+    def device(self) -> torch.device:
+        raise NotImplementedError
+
+    # @device.setter
+    # def device(self, value: torch.device) -> None:
+    #     self._device = value
 
     @property
     def world_size(self) -> int:
@@ -257,7 +263,9 @@ class Dist2MultiProcessTestCase(MultiProcessTestCase):
 
 
 class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
-    device = torch.device("cpu")
+    @property
+    def device(self) -> torch.device:
+        return torch.device("cpu")
 
     @requires_gloo()
     def new_group(self) -> torch.distributed.ProcessGroup:
@@ -274,6 +282,10 @@ class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
 
 
 class ProcessGroupNCCLTest(Dist2MultiProcessTestCase):
+    @property
+    def device(self) -> torch.device:
+        return torch.device("cuda", self.rank)
+
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
     def new_group(self) -> torch.distributed.ProcessGroup:
@@ -282,8 +294,6 @@ class ProcessGroupNCCLTest(Dist2MultiProcessTestCase):
         os.environ["MASTER_ADDR"] = "127.0.0.1"
         os.environ["MASTER_PORT"] = "29501"
 
-        self.device = torch.device("cuda", self.rank)
-
         return dist2.new_group(
             backend="nccl",
             timeout=timedelta(seconds=60),
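The removal of the self.device assignment in the last hunk follows from the refactor above: once device is a property with no setter, assigning to it raises AttributeError. A minimal, self-contained demonstration (the class name here is illustrative):

import torch

class Case:
    @property
    def device(self) -> torch.device:
        return torch.device("cuda", 0)

c = Case()
try:
    c.device = torch.device("cuda", 1)  # rejected: the property defines no setter
except AttributeError as exc:
    print("assignment rejected:", exc)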