Revert "Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)"

This reverts commit 1b397420f22b22f90a1093233ecd9167656e50cb.

Reverted https://github.com/pytorch/pytorch/pull/165716 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165716#issuecomment-3418083391))
This commit is contained in:
PyTorch MergeBot
2025-10-18 09:15:49 +00:00
parent 4740ce7787
commit beb6b62e8c
8 changed files with 25 additions and 150 deletions

View File

@ -17,7 +17,6 @@ from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorConverter,
DTensorTestBase,
with_comms,
@ -705,12 +704,6 @@ class DistTensorOpsTest(DTensorTestBase):
@with_comms
def test_dtensor_dtype_conversion(self):
from torch.distributed.tensor.debug import (
_clear_sharding_prop_cache,
_get_sharding_prop_cache_info,
)
_clear_sharding_prop_cache()
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
# by default we start from bf16 dtype
@ -729,6 +722,8 @@ class DistTensorOpsTest(DTensorTestBase):
self.assertEqual(bf16_sharded_dtensor1.dtype, torch.bfloat16)
self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16)
from torch.distributed.tensor.debug import _get_sharding_prop_cache_info
# by this point we only have cache misses
hits, misses, _, _ = _get_sharding_prop_cache_info()
self.assertEqual(hits, 0)
@ -780,7 +775,7 @@ class DistTensorOpsTest(DTensorTestBase):
)
def _test_split_on_partial(self, reduce_op: str, split_size: int, split_dim: int):
self.init_manual_seed_for_rank()
torch.manual_seed(self.rank)
mesh = self.build_device_mesh()
partial_tensor = torch.randn(8, 8, device=self.device_type)
@ -827,9 +822,5 @@ class DistTensorOpsTest(DTensorTestBase):
self.assertEqual(x.full_tensor(), y)
DistTensorOpsTestWithLocalTensor = create_local_tensor_test_class(
DistTensorOpsTest,
)
if __name__ == "__main__":
run_tests()