[DTensor] Renamed shard_spec -> placements in test file (#113917)
Public APIs like `from_local` and `distribute_tensor` name the argument as `placements`, not `shard_spec` anymore. This was a direct find and replace.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113917
Approved by: https://github.com/wanchaol
ghstack dependencies: #113654, #113903
Committed by: PyTorch MergeBot
Parent: 8372983fe3
Commit: e360f4c6dd
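For illustration only (not part of the commit), below is a minimal sketch of the public argument name that the renamed test variables now mirror. It assumes a default process group has already been initialized (e.g. the script is launched with torchrun) and uses the `torch.distributed._tensor` import path from around the time of this commit; newer releases expose the same names under `torch.distributed.tensor`.

```python
# Hypothetical usage sketch, not taken from the commit: both public APIs accept
# the sharding layout under the argument name `placements`.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Shard, distribute_tensor

# Assumes torchrun (or similar) has already initialized the default process group.
world_size = dist.get_world_size()
device_mesh = DeviceMesh("cpu", list(range(world_size)))

placements = [Shard(0)]  # shard dimension 0 across the mesh

# Wrap an existing per-rank local tensor as a DTensor.
local_tensor = torch.randn(3, 3)
sharded = DTensor.from_local(local_tensor, device_mesh, placements)

# Shard a full (global) tensor across the mesh.
global_tensor = torch.randn(world_size * 3, 3)
sharded2 = distribute_tensor(global_tensor, device_mesh, placements)
```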
@@ -45,13 +45,13 @@ class DTensorTest(DTensorTestBase):
     @with_comms
     def test_dtensor_constructor(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        shard_spec = [Shard(0)]
+        placements = [Shard(0)]
         local_tensor = torch.randn(3, 3, requires_grad=True)
         dist_tensor_shape = torch.Size([self.world_size * 3, 3])
         dist_tensor = DTensor(
             local_tensor,
             device_mesh,
-            shard_spec,
+            placements,
             shape=dist_tensor_shape,
             dtype=local_tensor.dtype,
             requires_grad=True,
@@ -63,7 +63,7 @@ class DTensorTest(DTensorTestBase):
             DTensor(
                 local_tensor,
                 device_mesh,
-                shard_spec,
+                placements,
                 shape=dist_tensor_shape,
                 dtype=local_tensor.dtype,
                 requires_grad=False,
@@ -75,7 +75,7 @@ class DTensorTest(DTensorTestBase):
         dist_tensor = DTensor(
             local_tensor,
             device_mesh,
-            shard_spec,
+            placements,
             shape=dist_tensor_shape,
             dtype=local_tensor.dtype,
             requires_grad=True,
@@ -165,9 +165,9 @@ class DTensorTest(DTensorTestBase):
     @with_comms
     def test_from_local(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        shard_spec = [Shard(0)]
+        placements = [Shard(0)]
         local_tensor = torch.randn(3, 3)
-        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
         self.assertEqual(sharded_tensor.size(), torch.Size([self.world_size * 3, 3]))

         replica_spec = [Replicate()]
@@ -184,14 +184,14 @@ class DTensorTest(DTensorTestBase):
         local_tensor_temp = local_tensor_with_grad * 3
         # create the dist tensor with non leaf local tensor, dist tensor created
         # should also be non leaf node
-        dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, shard_spec)
+        dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements)
         self.assertFalse(dist_tensor.is_leaf)
         # do some random operations on dist tensor
         output = dist_tensor * 3
         self.assertIsInstance(output, DTensor)
         # trigger .backward() on dist tensor directly
         local_grad = torch.ones(3, 3)
-        grad_output = DTensor.from_local(local_grad, device_mesh, shard_spec)
+        grad_output = DTensor.from_local(local_grad, device_mesh, placements)
         # run backward directly on dist tensor
         output.backward(grad_output)
         # check it gradients flow back to original torch.Tensor
@@ -263,15 +263,15 @@ class DTensorTest(DTensorTestBase):
     @with_comms
     def test_from_local_negative_dim(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        shard_spec = [Shard(-1)]
+        placements = [Shard(-1)]
         local_tensor = torch.randn(3, 3)
-        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
         self.assertEqual(sharded_tensor.placements[0].dim, 1)

     @with_comms
     def test_to_local(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        shard_spec = (Shard(0),)
+        placements = (Shard(0),)
         dist_tensor_shape = torch.Size([self.world_size * 3, 3])
         local_tensor_with_grad = torch.randn(
             3, 3, device=self.device_type, requires_grad=True
@@ -280,7 +280,7 @@ class DTensorTest(DTensorTestBase):
         sharded_tensor = DTensor(
             local_tensor_with_grad,
             device_mesh,
-            shard_spec,
+            placements,
             shape=dist_tensor_shape,
             dtype=local_tensor_with_grad.dtype,
             requires_grad=True,
@@ -322,10 +322,10 @@ class DTensorTest(DTensorTestBase):
     @with_comms
     def test_to_local_grad_hint(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        shard_spec = (Shard(0),)
+        placements = (Shard(0),)
         global_tensor = torch.ones(8, 3, requires_grad=True)

-        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, shard_spec)
+        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
         local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local(
             grad_placements=[_Partial()]
         )
@@ -337,10 +337,10 @@ class DTensorTest(DTensorTestBase):
     @with_comms
     def test_full_tensor_sync(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        shard_spec = (Shard(0),)
+        placements = (Shard(0),)
         global_tensor = torch.ones(8, 3, requires_grad=True)

-        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, shard_spec)
+        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
         full_out = sharded_dtensor.full_tensor()
         self.assertFalse(isinstance(full_out, AsyncCollectiveTensor))
         self.assertEqual(full_out, global_tensor)
@@ -348,10 +348,10 @@ class DTensorTest(DTensorTestBase):
     @with_comms
     def test_full_tensor_grad_hint(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        shard_spec = (Shard(0),)
+        placements = (Shard(0),)
         global_tensor = torch.ones(8, 3, requires_grad=True)

-        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, shard_spec)
+        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
         local_out = sharded_dtensor.full_tensor(grad_placements=[_Partial()])
         local_out.sum().backward()

@@ -424,7 +424,7 @@ class DTensorTest(DTensorTestBase):
     def test_from_local_then_to_local(self):
         # this test ensure end to end from torch.Tensor -> dist tensor -> torch.Tensor works
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        shard_spec = [Shard(0)]
+        placements = [Shard(0)]

         # step 1. construct from construct local tensor
         local_tensor_with_grad = torch.randn(
@@ -434,7 +434,7 @@ class DTensorTest(DTensorTestBase):
         local_tensor_temp = local_tensor_with_grad + 8
         # step 2. create the dist tensor with non leaf local tensor, dist tensor
         # created should also be non leaf node
-        dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, shard_spec)
+        dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements)
         self.assertFalse(dist_tensor.is_leaf)
         # do some random operations on dist tensor
         output = dist_tensor * 6
@@ -456,23 +456,23 @@ class DTensorTest(DTensorTestBase):
     @with_comms
     def test_dtensor_spec_read_only_after_set(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        shard_spec = [Shard(0)]
+        placements = [Shard(0)]
         local_tensor = torch.randn(3, 3)
-        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)

-        # modify shard_spec, and dist_tensor's spec should not be changed
-        shard_spec[0] = Replicate()
-        self.assertTrue(sharded_tensor.placements is not shard_spec)
-        self.assertNotEqual(sharded_tensor.placements, shard_spec)
+        # modify placements, and dist_tensor's spec should not be changed
+        placements[0] = Replicate()
+        self.assertTrue(sharded_tensor.placements is not placements)
+        self.assertNotEqual(sharded_tensor.placements, placements)

     @with_comms
     def test_dtensor_spec_hash(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        shard_spec = [Shard(0)]
+        placements = [Shard(0)]
         local_tensor = torch.randn(3, 3)
         local_tensor2 = torch.randn(3, 3)
-        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
-        sharded_tensor2 = DTensor.from_local(local_tensor2, device_mesh, shard_spec)
+        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
+        sharded_tensor2 = DTensor.from_local(local_tensor2, device_mesh, placements)
         # note that DTensorSpec without real tensor data, so the hash would be the same
         # as long as the mesh, placements and tensor properties are the same
         self.assertEqual(hash(sharded_tensor._spec), hash(sharded_tensor2._spec))
@@ -488,9 +488,9 @@ class DTensorTest(DTensorTestBase):
     @with_comms
     def test_dtensor_properties(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-        shard_spec = [Shard(0)]
+        placements = [Shard(0)]
         local_tensor = torch.randn(3, 3)
-        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
         self.assertEqual(sharded_tensor.device.type, self.device_type)

     @with_comms
@@ -498,9 +498,9 @@ class DTensorTest(DTensorTestBase):
         import io

         device_mesh = self.build_device_mesh()
-        shard_spec = [Shard(0)]
+        placements = [Shard(0)]
         local_tensor = torch.randn(3, 3)
-        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
         buffer = io.BytesIO()
         torch.save(sharded_tensor, buffer)
         buffer.seek(0)
@@ -526,25 +526,25 @@ class DTensorMeshTest(DTensorTestBase):

         # construct from a cpu local tensor with cuda device mesh
         # should automatically convert the dist tensor to cuda
-        shard_spec = [Shard(0)]
+        placements = [Shard(0)]
         local_tensor = torch.randn(3, 3)
-        dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
         self.assertEqual(dist_tensor.device.type, self.device_type)
         self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

     @with_comms
     def test_dtensor_api_device_mesh_context_manager(self):
         with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh:
-            shard_spec = [Shard(0)]
+            placements = [Shard(0)]
             local_tensor = torch.randn(3, 3)
             sharded_tensor = DTensor.from_local(
-                local_tensor, device_mesh=mesh, placements=shard_spec
+                local_tensor, device_mesh=mesh, placements=placements
             )

         with DeviceMesh(self.device_type, list(range(self.world_size))):
-            shard_spec = [Shard(0)]
+            placements = [Shard(0)]
             local_tensor = torch.randn(3, 3)
-            sharded_tensor = DTensor.from_local(local_tensor, placements=shard_spec)
+            sharded_tensor = DTensor.from_local(local_tensor, placements=placements)
             replica_spec = [Replicate()]
             replica_tensor = sharded_tensor.redistribute(placements=replica_spec)
             self.assertEqual(
@@ -552,10 +552,10 @@ class DTensorMeshTest(DTensorTestBase):
             )

         with DeviceMesh(self.device_type, torch.arange(self.world_size)):
-            shard_spec = [Shard(0)]
+            placements = [Shard(0)]
             global_shape = torch.Size([3 * self.world_size, 3])
             global_tensor = torch.randn(global_shape)
-            sharded_tensor = distribute_tensor(global_tensor, placements=shard_spec)
+            sharded_tensor = distribute_tensor(global_tensor, placements=placements)
             self.assertEqual(sharded_tensor.to_local().shape, torch.Size([3, 3]))

             mesh_2d = DeviceMesh(
@@ -568,7 +568,7 @@ class DTensorMeshTest(DTensorTestBase):

                 self.assertEqual(tensor_2d.to_local().shape, torch.Size([3 * 4, 3]))

-            sharded_after_2d = distribute_tensor(global_tensor, placements=shard_spec)
+            sharded_after_2d = distribute_tensor(global_tensor, placements=placements)
             self.assertEqual(sharded_after_2d.to_local().shape, torch.Size([3, 3]))

     @with_comms
@@ -578,9 +578,9 @@ class DTensorMeshTest(DTensorTestBase):
         mesh = DeviceMesh(self.device_type, mesh_tensor)

         # construct a dist tensor on 2d device mesh and test if works
-        shard_spec = [Shard(0), Shard(1)]
+        placements = [Shard(0), Shard(1)]
         local_tensor = torch.randn(3, 3)
-        dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
         self.assertEqual(
             dist_tensor.size(), torch.Size([3 * mesh.size(0), 3 * mesh.size(1)])
         )
@@ -600,17 +600,17 @@ class DTensorMeshTest(DTensorTestBase):
         mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2)
         mesh = DeviceMesh(self.device_type, mesh_tensor)
         # construct a dist tensor on 3d device mesh and test if works
-        shard_spec = [Shard(0), Shard(1), Shard(2)]
+        placements = [Shard(0), Shard(1), Shard(2)]
         local_tensor = torch.randn(3, 3, 3)
-        dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
         self.assertEqual(dist_tensor.size(), torch.Size([6, 6, 6]))
         self.assertEqual(dist_tensor.device.type, self.device_type)
         self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

         # construct a dist tensor on 3d device mesh with some shards on same dim
-        shard_spec = [Shard(0), Shard(0), Shard(2)]
+        placements = [Shard(0), Shard(0), Shard(2)]
         local_tensor = torch.randn(3, 3, 3)
-        dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
         self.assertEqual(dist_tensor.size(), torch.Size([12, 3, 6]))
         self.assertEqual(dist_tensor.device.type, self.device_type)
         self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
@@ -647,8 +647,8 @@ class DTensorMeshTest(DTensorTestBase):

         # loop through all sharding specs and check local shard offsets
         logical_tensor = torch.randn(tensor_shape)
-        for shard_spec, expected_shard_offsets in shard_spec_and_offsets:
-            dtensor = distribute_tensor(logical_tensor, device_mesh, shard_spec)
+        for placements, expected_shard_offsets in shard_spec_and_offsets:
+            dtensor = distribute_tensor(logical_tensor, device_mesh, placements)
             _, offset = compute_local_shape_and_global_offset(
                 dtensor.shape, device_mesh, dtensor.placements
             )