mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dtensor] Improve from_local API with run_check (#130289)
As titled, this PR: 1. Switches `run_check` to default to False and adds extra docs/comments about the correctness guarantee. Since I observed that many callers forget to pass run_check=False, we should simply default to not performing the metadata check and make our documentation explicit. 2. Implements the metadata check by picking up the changes from https://github.com/pytorch/pytorch/pull/115229. 3. Improves the from_local documentation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130289 Approved by: https://github.com/awgu, https://github.com/wz337 ghstack dependencies: #130286, #130287, #130288
This commit is contained in:
committed by
PyTorch MergeBot
parent
3342f3aa4e
commit
a7cfe40c9b
@ -822,6 +822,53 @@ class DTensorMeshTest(DTensorTestBase):
|
||||
(numel_1_tensor + sharded_dtensor).to_local(), numel_1_tensor + local_tensor
|
||||
)
|
||||
|
||||
@with_comms
def test_metadata_consistency_check(self):
    """Exercise ``DTensor.from_local`` with rank-divergent tensor metadata.

    Three scenarios are run back to back. In each one rank 0 ends up with
    a local tensor whose metadata differs from the tensor built on the
    other ranks (dtype, then ``requires_grad``, then a transposed layout).
    For every scenario the test expects:

    * ``run_check=True`` to raise ``ValueError``, and
    * ``run_check=False`` to complete without raising ``ValueError``.
    """
    mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    shard_on_dim0 = [Shard(0)]

    def assert_raises_only_when_checked(tensor):
        # With the check enabled, the mismatch must surface as ValueError.
        with self.assertRaises(ValueError):
            DTensor.from_local(tensor, mesh, shard_on_dim0, run_check=True)

        # With the check disabled, the very same call must go through.
        try:
            DTensor.from_local(tensor, mesh, shard_on_dim0, run_check=False)
        except ValueError:
            self.fail("Unexpected ValueError raised with run_check=False")

    # Scenario 1: rank 0 holds a float64 tensor instead of float32.
    dtype_case = torch.randn(3, 3, requires_grad=True, dtype=torch.float32)
    if self.rank == 0:
        dtype_case = dtype_case.to(dtype=torch.float64)
    assert_raises_only_when_checked(dtype_case)

    # Scenario 2: rank 0 turns requires_grad off on its tensor.
    grad_case = torch.randn(3, 3, requires_grad=True, dtype=torch.float32)
    if self.rank == 0:
        grad_case.requires_grad = False
    assert_raises_only_when_checked(grad_case)

    # Scenario 3: rank 0 transposes its tensor, which changes the stride.
    stride_case = torch.randn(3, 4, requires_grad=True, dtype=torch.float32)
    if self.rank == 0:
        stride_case = stride_case.t()
    assert_raises_only_when_checked(stride_case)
|
||||
|
||||
|
||||
class TestDTensorPlacementTypes(DTensorTestBase):
|
||||
@property
|
||||
|
Reference in New Issue
Block a user