[dtensor] Improve from_local API with run_check (#130289)

as titled, this PR:
1. switch `run_check` to be by default False and add extra doc/comments
   about the correctness guarantee. Since I observed so many calls
forget to use run_check=False, we should simply switch to not perform
metadata check and make our documentation explicit
2. Implement metadata check by picking up the changes from https://github.com/pytorch/pytorch/pull/115229
3. Improve the from_local documentation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130289
Approved by: https://github.com/awgu, https://github.com/wz337
ghstack dependencies: #130286, #130287, #130288
This commit is contained in:
Wanchao Liang
2024-07-13 17:20:00 -07:00
committed by PyTorch MergeBot
parent 3342f3aa4e
commit a7cfe40c9b
3 changed files with 97 additions and 17 deletions

View File

@ -822,6 +822,53 @@ class DTensorMeshTest(DTensorTestBase):
(numel_1_tensor + sharded_dtensor).to_local(), numel_1_tensor + local_tensor
)
@with_comms
def test_metadata_consistency_check(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
placements = [Shard(0)]
# Create a local tensor with specific metadata and check dtype change
local_tensor = torch.randn(3, 3, requires_grad=True, dtype=torch.float32)
if self.rank == 0:
local_tensor = local_tensor.to(dtype=torch.float64)
with self.assertRaises(ValueError):
DTensor.from_local(local_tensor, device_mesh, placements, run_check=True)
try:
DTensor.from_local(local_tensor, device_mesh, placements, run_check=False)
except ValueError:
self.fail("Unexpected ValueError raised with run_check=False")
# Create a local tensor with specific metadata and check requires_grad change
local_tensor = torch.randn(3, 3, requires_grad=True, dtype=torch.float32)
if self.rank == 0:
local_tensor.requires_grad = False
with self.assertRaises(ValueError):
DTensor.from_local(local_tensor, device_mesh, placements, run_check=True)
try:
DTensor.from_local(local_tensor, device_mesh, placements, run_check=False)
except ValueError:
self.fail("Unexpected ValueError raised with run_check=False")
# Create a local tensor with specific metadata and check stride change
local_tensor = torch.randn(3, 4, requires_grad=True, dtype=torch.float32)
if self.rank == 0:
local_tensor = local_tensor.t() # transpose changes the stride
with self.assertRaises(ValueError):
DTensor.from_local(local_tensor, device_mesh, placements, run_check=True)
try:
DTensor.from_local(local_tensor, device_mesh, placements, run_check=False)
except ValueError:
self.fail("Unexpected ValueError raised with run_check=False")
class TestDTensorPlacementTypes(DTensorTestBase):
@property