[DTensor] Renamed shard_spec -> placements in test file (#113917)

Public APIs like `from_local` and `distribute_tensor` now name this argument `placements` rather than `shard_spec`. This was a direct find-and-replace.
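
For reference, a minimal sketch of how the renamed keyword is passed to both APIs (assumptions: the `torch.distributed._tensor` import path used by this test file, an already-initialized process group, and a 1-D CUDA mesh; none of this is prescribed by the PR itself):

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Shard, distribute_tensor

# Assumes the default process group is already initialized (e.g. via torchrun)
# and builds a 1-D mesh over all ranks, mirroring the tests below.
world_size = dist.get_world_size()
device_mesh = DeviceMesh("cuda", list(range(world_size)))
placements = [Shard(0)]  # shard dim 0 across the mesh

# Both public entry points now take the layout through `placements`.
sharded = DTensor.from_local(torch.randn(3, 3), device_mesh, placements=placements)
distributed = distribute_tensor(
    torch.randn(world_size * 3, 3), device_mesh, placements=placements
)
```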

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113917
Approved by: https://github.com/wanchaol
ghstack dependencies: #113654, #113903
Author: Andrew Gu
Date: 2023-11-16 20:30:45 -08:00
Committed by: PyTorch MergeBot
Parent: 8372983fe3
Commit: e360f4c6dd

@@ -45,13 +45,13 @@ class DTensorTest(DTensorTestBase):
@with_comms
def test_dtensor_constructor(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-shard_spec = [Shard(0)]
+placements = [Shard(0)]
local_tensor = torch.randn(3, 3, requires_grad=True)
dist_tensor_shape = torch.Size([self.world_size * 3, 3])
dist_tensor = DTensor(
local_tensor,
device_mesh,
-shard_spec,
+placements,
shape=dist_tensor_shape,
dtype=local_tensor.dtype,
requires_grad=True,
@@ -63,7 +63,7 @@ class DTensorTest(DTensorTestBase):
DTensor(
local_tensor,
device_mesh,
-shard_spec,
+placements,
shape=dist_tensor_shape,
dtype=local_tensor.dtype,
requires_grad=False,
@@ -75,7 +75,7 @@ class DTensorTest(DTensorTestBase):
dist_tensor = DTensor(
local_tensor,
device_mesh,
-shard_spec,
+placements,
shape=dist_tensor_shape,
dtype=local_tensor.dtype,
requires_grad=True,
@@ -165,9 +165,9 @@ class DTensorTest(DTensorTestBase):
@with_comms
def test_from_local(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-shard_spec = [Shard(0)]
+placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
-sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
self.assertEqual(sharded_tensor.size(), torch.Size([self.world_size * 3, 3]))
replica_spec = [Replicate()]
@@ -184,14 +184,14 @@ class DTensorTest(DTensorTestBase):
local_tensor_temp = local_tensor_with_grad * 3
# create the dist tensor with non leaf local tensor, dist tensor created
# should also be non leaf node
-dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, shard_spec)
+dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements)
self.assertFalse(dist_tensor.is_leaf)
# do some random operations on dist tensor
output = dist_tensor * 3
self.assertIsInstance(output, DTensor)
# trigger .backward() on dist tensor directly
local_grad = torch.ones(3, 3)
-grad_output = DTensor.from_local(local_grad, device_mesh, shard_spec)
+grad_output = DTensor.from_local(local_grad, device_mesh, placements)
# run backward directly on dist tensor
output.backward(grad_output)
# check it gradients flow back to original torch.Tensor
@@ -263,15 +263,15 @@ class DTensorTest(DTensorTestBase):
@with_comms
def test_from_local_negative_dim(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-shard_spec = [Shard(-1)]
+placements = [Shard(-1)]
local_tensor = torch.randn(3, 3)
-sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
self.assertEqual(sharded_tensor.placements[0].dim, 1)
@with_comms
def test_to_local(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-shard_spec = (Shard(0),)
+placements = (Shard(0),)
dist_tensor_shape = torch.Size([self.world_size * 3, 3])
local_tensor_with_grad = torch.randn(
3, 3, device=self.device_type, requires_grad=True
@@ -280,7 +280,7 @@ class DTensorTest(DTensorTestBase):
sharded_tensor = DTensor(
local_tensor_with_grad,
device_mesh,
-shard_spec,
+placements,
shape=dist_tensor_shape,
dtype=local_tensor_with_grad.dtype,
requires_grad=True,
@@ -322,10 +322,10 @@ class DTensorTest(DTensorTestBase):
@with_comms
def test_to_local_grad_hint(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-shard_spec = (Shard(0),)
+placements = (Shard(0),)
global_tensor = torch.ones(8, 3, requires_grad=True)
-sharded_dtensor = distribute_tensor(global_tensor, device_mesh, shard_spec)
+sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local(
grad_placements=[_Partial()]
)
@@ -337,10 +337,10 @@ class DTensorTest(DTensorTestBase):
@with_comms
def test_full_tensor_sync(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-shard_spec = (Shard(0),)
+placements = (Shard(0),)
global_tensor = torch.ones(8, 3, requires_grad=True)
-sharded_dtensor = distribute_tensor(global_tensor, device_mesh, shard_spec)
+sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
full_out = sharded_dtensor.full_tensor()
self.assertFalse(isinstance(full_out, AsyncCollectiveTensor))
self.assertEqual(full_out, global_tensor)
@@ -348,10 +348,10 @@ class DTensorTest(DTensorTestBase):
@with_comms
def test_full_tensor_grad_hint(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-shard_spec = (Shard(0),)
+placements = (Shard(0),)
global_tensor = torch.ones(8, 3, requires_grad=True)
-sharded_dtensor = distribute_tensor(global_tensor, device_mesh, shard_spec)
+sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
local_out = sharded_dtensor.full_tensor(grad_placements=[_Partial()])
local_out.sum().backward()
@@ -424,7 +424,7 @@ class DTensorTest(DTensorTestBase):
def test_from_local_then_to_local(self):
# this test ensure end to end from torch.Tensor -> dist tensor -> torch.Tensor works
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-shard_spec = [Shard(0)]
+placements = [Shard(0)]
# step 1. construct from construct local tensor
local_tensor_with_grad = torch.randn(
@@ -434,7 +434,7 @@ class DTensorTest(DTensorTestBase):
local_tensor_temp = local_tensor_with_grad + 8
# step 2. create the dist tensor with non leaf local tensor, dist tensor
# created should also be non leaf node
-dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, shard_spec)
+dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements)
self.assertFalse(dist_tensor.is_leaf)
# do some random operations on dist tensor
output = dist_tensor * 6
@@ -456,23 +456,23 @@ class DTensorTest(DTensorTestBase):
@with_comms
def test_dtensor_spec_read_only_after_set(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-shard_spec = [Shard(0)]
+placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
-sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
-# modify shard_spec, and dist_tensor's spec should not be changed
-shard_spec[0] = Replicate()
-self.assertTrue(sharded_tensor.placements is not shard_spec)
-self.assertNotEqual(sharded_tensor.placements, shard_spec)
+# modify placements, and dist_tensor's spec should not be changed
+placements[0] = Replicate()
+self.assertTrue(sharded_tensor.placements is not placements)
+self.assertNotEqual(sharded_tensor.placements, placements)
@with_comms
def test_dtensor_spec_hash(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-shard_spec = [Shard(0)]
+placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
local_tensor2 = torch.randn(3, 3)
-sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
-sharded_tensor2 = DTensor.from_local(local_tensor2, device_mesh, shard_spec)
+sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
+sharded_tensor2 = DTensor.from_local(local_tensor2, device_mesh, placements)
# note that DTensorSpec without real tensor data, so the hash would be the same
# as long as the mesh, placements and tensor properties are the same
self.assertEqual(hash(sharded_tensor._spec), hash(sharded_tensor2._spec))
@@ -488,9 +488,9 @@ class DTensorTest(DTensorTestBase):
@with_comms
def test_dtensor_properties(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
-shard_spec = [Shard(0)]
+placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
-sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
self.assertEqual(sharded_tensor.device.type, self.device_type)
@with_comms
@@ -498,9 +498,9 @@ class DTensorTest(DTensorTestBase):
import io
device_mesh = self.build_device_mesh()
-shard_spec = [Shard(0)]
+placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
-sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
buffer = io.BytesIO()
torch.save(sharded_tensor, buffer)
buffer.seek(0)
@@ -526,25 +526,25 @@ class DTensorMeshTest(DTensorTestBase):
# construct from a cpu local tensor with cuda device mesh
# should automatically convert the dist tensor to cuda
-shard_spec = [Shard(0)]
+placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
-dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
self.assertEqual(dist_tensor.device.type, self.device_type)
self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
@with_comms
def test_dtensor_api_device_mesh_context_manager(self):
with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh:
-shard_spec = [Shard(0)]
+placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
sharded_tensor = DTensor.from_local(
-local_tensor, device_mesh=mesh, placements=shard_spec
+local_tensor, device_mesh=mesh, placements=placements
)
with DeviceMesh(self.device_type, list(range(self.world_size))):
-shard_spec = [Shard(0)]
+placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
-sharded_tensor = DTensor.from_local(local_tensor, placements=shard_spec)
+sharded_tensor = DTensor.from_local(local_tensor, placements=placements)
replica_spec = [Replicate()]
replica_tensor = sharded_tensor.redistribute(placements=replica_spec)
self.assertEqual(
@@ -552,10 +552,10 @@ class DTensorMeshTest(DTensorTestBase):
)
with DeviceMesh(self.device_type, torch.arange(self.world_size)):
-shard_spec = [Shard(0)]
+placements = [Shard(0)]
global_shape = torch.Size([3 * self.world_size, 3])
global_tensor = torch.randn(global_shape)
-sharded_tensor = distribute_tensor(global_tensor, placements=shard_spec)
+sharded_tensor = distribute_tensor(global_tensor, placements=placements)
self.assertEqual(sharded_tensor.to_local().shape, torch.Size([3, 3]))
mesh_2d = DeviceMesh(
@@ -568,7 +568,7 @@ class DTensorMeshTest(DTensorTestBase):
self.assertEqual(tensor_2d.to_local().shape, torch.Size([3 * 4, 3]))
-sharded_after_2d = distribute_tensor(global_tensor, placements=shard_spec)
+sharded_after_2d = distribute_tensor(global_tensor, placements=placements)
self.assertEqual(sharded_after_2d.to_local().shape, torch.Size([3, 3]))
@with_comms
@@ -578,9 +578,9 @@ class DTensorMeshTest(DTensorTestBase):
mesh = DeviceMesh(self.device_type, mesh_tensor)
# construct a dist tensor on 2d device mesh and test if works
-shard_spec = [Shard(0), Shard(1)]
+placements = [Shard(0), Shard(1)]
local_tensor = torch.randn(3, 3)
-dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
self.assertEqual(
dist_tensor.size(), torch.Size([3 * mesh.size(0), 3 * mesh.size(1)])
)
@@ -600,17 +600,17 @@ class DTensorMeshTest(DTensorTestBase):
mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2)
mesh = DeviceMesh(self.device_type, mesh_tensor)
# construct a dist tensor on 3d device mesh and test if works
-shard_spec = [Shard(0), Shard(1), Shard(2)]
+placements = [Shard(0), Shard(1), Shard(2)]
local_tensor = torch.randn(3, 3, 3)
-dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
self.assertEqual(dist_tensor.size(), torch.Size([6, 6, 6]))
self.assertEqual(dist_tensor.device.type, self.device_type)
self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
# construct a dist tensor on 3d device mesh with some shards on same dim
-shard_spec = [Shard(0), Shard(0), Shard(2)]
+placements = [Shard(0), Shard(0), Shard(2)]
local_tensor = torch.randn(3, 3, 3)
-dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
self.assertEqual(dist_tensor.size(), torch.Size([12, 3, 6]))
self.assertEqual(dist_tensor.device.type, self.device_type)
self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
@@ -647,8 +647,8 @@ class DTensorMeshTest(DTensorTestBase):
# loop through all sharding specs and check local shard offsets
logical_tensor = torch.randn(tensor_shape)
-for shard_spec, expected_shard_offsets in shard_spec_and_offsets:
-dtensor = distribute_tensor(logical_tensor, device_mesh, shard_spec)
+for placements, expected_shard_offsets in shard_spec_and_offsets:
+dtensor = distribute_tensor(logical_tensor, device_mesh, placements)
_, offset = compute_local_shape_and_global_offset(
dtensor.shape, device_mesh, dtensor.placements
)