diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py
index 4806c1b71d0d..66d80f604551 100644
--- a/test/distributed/tensor/test_attention.py
+++ b/test/distributed/tensor/test_attention.py
@@ -771,5 +771,40 @@ class TestCPCustomOps(DTensorTestBase):
         torch.library.opcheck(flex_cp_allgather, example)
 
 
+class TestSharding(DTensorTestBase):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @skip_if_lt_x_gpu(2)
+    @with_comms
+    def test_context_parallel_shard(self) -> None:
+        B = 4
+        seq_len = 32
+
+        device_mesh = init_device_mesh(
+            mesh_shape=(2,), mesh_dim_names=("cp",), device_type=self.device_type
+        )
+        freqs_cis = torch.arange(0, seq_len, device=self.device_type)
+        q = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
+        k = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
+        v = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
+
+        load_balancer = _HeadTailLoadBalancer(
+            seq_len, self.world_size, torch.device(self.device_type)
+        )
+        freqs_cis_shard, q_shard, k_shard, v_shard = _context_parallel_shard(
+            device_mesh, [freqs_cis, q, k, v], [0, 1, 1, 1], load_balancer=load_balancer
+        )
+        self.assertEqual(freqs_cis_shard.size(), (seq_len // 2,))
+        chunks = freqs_cis.chunk(self.world_size * 2)
+        self.assertEqual(
+            freqs_cis_shard,
+            torch.cat(
+                [chunks[self.rank], chunks[self.world_size * 2 - self.rank - 1]], dim=0
+            ),
+        )
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py
index 8d0a07bbd97f..9b89563a0ef9 100644
--- a/torch/distributed/tensor/experimental/_attention.py
+++ b/torch/distributed/tensor/experimental/_attention.py
@@ -1068,10 +1068,16 @@ def _context_parallel_buffers(
     for buffer, seq_dim in zip(buffers, buffer_seq_dims):
         if isinstance(buffer, torch.Tensor):
             # TODO: the load balance doesn't perform error handling.
+
+            # NOTE: assuming batch dim is 0
+
             if load_balance_indices is not None:
-                # NOTE: assuming batch dim is 0
+                # TODO: we should explicitly ask users to unsqueeze the batch dim.
+                # But this is a BC-breaking ask.
+                # However, what we do today is also not very safe.
                 idx_batch_size = load_balance_indices.size(0)
-                data_batch_size = buffer.size(0)
+                data_batch_size = buffer.size(0) if seq_dim > 0 else 1
+
                 if idx_batch_size != 1 and idx_batch_size != data_batch_size:
                     raise ValueError(
                         "Cannot rearrange buffer: "
@@ -1079,16 +1085,20 @@
                         f"but buffer has shape {buffer.shape}."
                     )
 
-                for i in range(data_batch_size):
-                    index = (
-                        load_balance_indices[0]  # identical load-balance in batch
-                        if idx_batch_size == 1
-                        else load_balance_indices[i]
+                if seq_dim == 0:
+                    buffer = torch.index_select(
+                        buffer, dim=0, index=load_balance_indices[0]
                     )
-                    buffer_batch_i = torch.index_select(
-                        buffer[i], dim=seq_dim - 1, index=index
-                    )
-                    buffer[i] = buffer_batch_i
+                else:
+                    indices = load_balance_indices
+                    if idx_batch_size == 1:
+                        size = [data_batch_size] + list(indices.size())[1:]
+                        indices = indices.expand(*size)
+
+                    for i in range(data_batch_size):
+                        buffer[i] = torch.index_select(
+                            buffer[i], dim=seq_dim - 1, index=indices[i]
+                        )
 
         # use DTensor to shard the buffer on sequence dimension, retain the local tensor
         sharded_buffer = distribute_tensor(
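
For reference, the head-tail rearrangement that the new test asserts can be sketched as follows: the sequence is split into 2 * world_size chunks, and rank r keeps chunk r (the head) together with chunk 2 * world_size - 1 - r (the tail), so each rank holds seq_len // world_size elements. The snippet below is a minimal illustration of that index pattern, not the _HeadTailLoadBalancer implementation; the head_tail_indices helper is hypothetical and assumes the sequence dimension is 0 and seq_len divides evenly by 2 * world_size.

# Minimal sketch of head-tail load balancing -- illustrative only, not the
# _HeadTailLoadBalancer implementation. Assumes the sequence dim is 0 and
# seq_len is divisible by 2 * world_size.
import torch


def head_tail_indices(seq_len: int, world_size: int, rank: int) -> torch.Tensor:
    # Split the sequence into 2 * world_size chunks; rank r keeps chunk r
    # (head) and chunk 2 * world_size - 1 - r (tail).
    chunk = seq_len // (2 * world_size)
    head = torch.arange(rank * chunk, (rank + 1) * chunk)
    tail_id = 2 * world_size - 1 - rank
    tail = torch.arange(tail_id * chunk, (tail_id + 1) * chunk)
    return torch.cat([head, tail])


if __name__ == "__main__":
    seq_len, world_size = 32, 2
    freqs_cis = torch.arange(seq_len)
    chunks = freqs_cis.chunk(world_size * 2)
    for rank in range(world_size):
        shard = torch.index_select(
            freqs_cis, dim=0, index=head_tail_indices(seq_len, world_size, rank)
        )
        # Matches the assertion in test_context_parallel_shard above.
        expected = torch.cat([chunks[rank], chunks[world_size * 2 - rank - 1]])
        assert torch.equal(shard, expected)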