mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)"
This reverts commit 1b397420f22b22f90a1093233ecd9167656e50cb. Reverted https://github.com/pytorch/pytorch/pull/165716 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165716#issuecomment-3418083391))
This commit is contained in:
@ -17,7 +17,6 @@ from torch.distributed.tensor.debug import CommDebugMode
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests, skipIfRocm
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
create_local_tensor_test_class,
|
||||
DTensorConverter,
|
||||
DTensorTestBase,
|
||||
with_comms,
|
||||
@ -705,12 +704,6 @@ class DistTensorOpsTest(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
def test_dtensor_dtype_conversion(self):
|
||||
from torch.distributed.tensor.debug import (
|
||||
_clear_sharding_prop_cache,
|
||||
_get_sharding_prop_cache_info,
|
||||
)
|
||||
|
||||
_clear_sharding_prop_cache()
|
||||
device_mesh = self.build_device_mesh()
|
||||
shard_spec = [Shard(0)]
|
||||
# by default we start from bf16 dtype
|
||||
@ -729,6 +722,8 @@ class DistTensorOpsTest(DTensorTestBase):
|
||||
self.assertEqual(bf16_sharded_dtensor1.dtype, torch.bfloat16)
|
||||
self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16)
|
||||
|
||||
from torch.distributed.tensor.debug import _get_sharding_prop_cache_info
|
||||
|
||||
# by this point we only have cache misses
|
||||
hits, misses, _, _ = _get_sharding_prop_cache_info()
|
||||
self.assertEqual(hits, 0)
|
||||
@ -780,7 +775,7 @@ class DistTensorOpsTest(DTensorTestBase):
|
||||
)
|
||||
|
||||
def _test_split_on_partial(self, reduce_op: str, split_size: int, split_dim: int):
|
||||
self.init_manual_seed_for_rank()
|
||||
torch.manual_seed(self.rank)
|
||||
mesh = self.build_device_mesh()
|
||||
|
||||
partial_tensor = torch.randn(8, 8, device=self.device_type)
|
||||
@ -827,9 +822,5 @@ class DistTensorOpsTest(DTensorTestBase):
|
||||
self.assertEqual(x.full_tensor(), y)
|
||||
|
||||
|
||||
DistTensorOpsTestWithLocalTensor = create_local_tensor_test_class(
|
||||
DistTensorOpsTest,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user